proto_file_parser.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. #include "test/cpp/util/proto_file_parser.h"
  19. #include <algorithm>
  20. #include <iostream>
  21. #include <sstream>
  22. #include <unordered_set>
  23. #include "absl/memory/memory.h"
  24. #include "absl/strings/str_split.h"
  25. #include <grpcpp/support/config.h>
  26. namespace grpc {
  27. namespace testing {
  namespace {
  // Returns true if the user-supplied method string `input` designates the
  // method whose fully-qualified descriptor name is `full_name`.
  //
  // '/' and '.' are both accepted as separators in `input`. After normalising
  // '/' to '.', `input` matches when it equals a trailing run of *whole*
  // dot-separated components of `full_name`. For "grpc.testing.EchoTestService.Echo"
  // the inputs "Echo", "EchoTestService/Echo", "grpc.testing.EchoTestService.Echo"
  // and the call-path form "/grpc.testing.EchoTestService/Echo" all match.
  // A partial component such as "cho" does not match, and neither does "Echo"
  // against "...BidiEcho". An empty `input` never matches.
  bool MethodNameMatch(const std::string& full_name, const std::string& input) {
    std::string clean_input = input;
    std::replace(clean_input.begin(), clean_input.end(), '/', '.');
    // A single leading separator (as in the "/package.Service/Method" path)
    // only marks the start of a component, so it is dropped before comparing.
    if (!clean_input.empty() && clean_input.front() == '.') {
      clean_input.erase(0, 1);
    }
    if (clean_input.empty() || clean_input.size() > full_name.size()) {
      return false;
    }
    const std::size_t start = full_name.size() - clean_input.size();
    if (full_name.compare(start, clean_input.size(), clean_input) != 0) {
      return false;
    }
    // Accept the suffix only if it starts on a component boundary; otherwise
    // "Echo" would also select "BidiEcho" and make lookups spuriously ambiguous.
    return start == 0 || full_name[start - 1] == '.';
  }
  }  // namespace
  40. class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector {
  41. public:
  42. explicit ErrorPrinter(ProtoFileParser* parser) : parser_(parser) {}
  43. void AddError(const std::string& filename, int line, int column,
  44. const std::string& message) override {
  45. std::ostringstream oss;
  46. oss << "error " << filename << " " << line << " " << column << " "
  47. << message << "\n";
  48. parser_->LogError(oss.str());
  49. }
  50. void AddWarning(const std::string& filename, int line, int column,
  51. const std::string& message) override {
  52. std::cerr << "warning " << filename << " " << line << " " << column << " "
  53. << message << std::endl;
  54. }
  55. private:
  56. ProtoFileParser* parser_; // not owned
  57. };
  58. ProtoFileParser::ProtoFileParser(const std::shared_ptr<grpc::Channel>& channel,
  59. const std::string& proto_path,
  60. const std::string& protofiles)
  61. : has_error_(false),
  62. dynamic_factory_(new protobuf::DynamicMessageFactory()) {
  63. std::vector<std::string> service_list;
  64. if (channel) {
  65. reflection_db_ =
  66. absl::make_unique<grpc::ProtoReflectionDescriptorDatabase>(channel);
  67. reflection_db_->GetServices(&service_list);
  68. }
  69. std::unordered_set<std::string> known_services;
  70. if (!protofiles.empty()) {
  71. for (const absl::string_view single_path : absl::StrSplit(
  72. proto_path, GRPC_CLI_PATH_SEPARATOR, absl::AllowEmpty())) {
  73. source_tree_.MapPath("", std::string(single_path));
  74. }
  75. error_printer_ = absl::make_unique<ErrorPrinter>(this);
  76. importer_ = absl::make_unique<protobuf::compiler::Importer>(
  77. &source_tree_, error_printer_.get());
  78. std::string file_name;
  79. std::stringstream ss(protofiles);
  80. while (std::getline(ss, file_name, ',')) {
  81. const auto* file_desc = importer_->Import(file_name);
  82. if (file_desc) {
  83. for (int i = 0; i < file_desc->service_count(); i++) {
  84. service_desc_list_.push_back(file_desc->service(i));
  85. known_services.insert(file_desc->service(i)->full_name());
  86. }
  87. } else {
  88. std::cerr << file_name << " not found" << std::endl;
  89. }
  90. }
  91. file_db_ =
  92. absl::make_unique<protobuf::DescriptorPoolDatabase>(*importer_->pool());
  93. }
  94. if (!reflection_db_ && !file_db_) {
  95. LogError("No available proto database");
  96. return;
  97. }
  98. if (!reflection_db_) {
  99. desc_db_ = std::move(file_db_);
  100. } else if (!file_db_) {
  101. desc_db_ = std::move(reflection_db_);
  102. } else {
  103. desc_db_ = absl::make_unique<protobuf::MergedDescriptorDatabase>(
  104. reflection_db_.get(), file_db_.get());
  105. }
  106. desc_pool_ = absl::make_unique<protobuf::DescriptorPool>(desc_db_.get());
  107. for (auto it = service_list.begin(); it != service_list.end(); it++) {
  108. if (known_services.find(*it) == known_services.end()) {
  109. if (const protobuf::ServiceDescriptor* service_desc =
  110. desc_pool_->FindServiceByName(*it)) {
  111. service_desc_list_.push_back(service_desc);
  112. known_services.insert(*it);
  113. }
  114. }
  115. }
  116. }
  117. ProtoFileParser::~ProtoFileParser() {}
  118. std::string ProtoFileParser::GetFullMethodName(const std::string& method) {
  119. has_error_ = false;
  120. if (known_methods_.find(method) != known_methods_.end()) {
  121. return known_methods_[method];
  122. }
  123. const protobuf::MethodDescriptor* method_descriptor = nullptr;
  124. for (auto it = service_desc_list_.begin(); it != service_desc_list_.end();
  125. it++) {
  126. const auto* service_desc = *it;
  127. for (int j = 0; j < service_desc->method_count(); j++) {
  128. const auto* method_desc = service_desc->method(j);
  129. if (MethodNameMatch(method_desc->full_name(), method)) {
  130. if (method_descriptor) {
  131. std::ostringstream error_stream;
  132. error_stream << "Ambiguous method names: ";
  133. error_stream << method_descriptor->full_name() << " ";
  134. error_stream << method_desc->full_name();
  135. LogError(error_stream.str());
  136. }
  137. method_descriptor = method_desc;
  138. }
  139. }
  140. }
  141. if (!method_descriptor) {
  142. LogError("Method name not found");
  143. }
  144. if (has_error_) {
  145. return "";
  146. }
  147. known_methods_[method] = method_descriptor->full_name();
  148. return method_descriptor->full_name();
  149. }
  150. std::string ProtoFileParser::GetFormattedMethodName(const std::string& method) {
  151. has_error_ = false;
  152. std::string formatted_method_name = GetFullMethodName(method);
  153. if (has_error_) {
  154. return "";
  155. }
  156. size_t last_dot = formatted_method_name.find_last_of('.');
  157. if (last_dot != std::string::npos) {
  158. formatted_method_name[last_dot] = '/';
  159. }
  160. formatted_method_name.insert(formatted_method_name.begin(), '/');
  161. return formatted_method_name;
  162. }
  163. std::string ProtoFileParser::GetMessageTypeFromMethod(const std::string& method,
  164. bool is_request) {
  165. has_error_ = false;
  166. std::string full_method_name = GetFullMethodName(method);
  167. if (has_error_) {
  168. return "";
  169. }
  170. const protobuf::MethodDescriptor* method_desc =
  171. desc_pool_->FindMethodByName(full_method_name);
  172. if (!method_desc) {
  173. LogError("Method not found");
  174. return "";
  175. }
  176. return is_request ? method_desc->input_type()->full_name()
  177. : method_desc->output_type()->full_name();
  178. }
  179. bool ProtoFileParser::IsStreaming(const std::string& method, bool is_request) {
  180. has_error_ = false;
  181. std::string full_method_name = GetFullMethodName(method);
  182. if (has_error_) {
  183. return false;
  184. }
  185. const protobuf::MethodDescriptor* method_desc =
  186. desc_pool_->FindMethodByName(full_method_name);
  187. if (!method_desc) {
  188. LogError("Method not found");
  189. return false;
  190. }
  191. return is_request ? method_desc->client_streaming()
  192. : method_desc->server_streaming();
  193. }
  194. std::string ProtoFileParser::GetSerializedProtoFromMethod(
  195. const std::string& method, const std::string& formatted_proto,
  196. bool is_request, bool is_json_format) {
  197. has_error_ = false;
  198. std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
  199. if (has_error_) {
  200. return "";
  201. }
  202. return GetSerializedProtoFromMessageType(message_type_name, formatted_proto,
  203. is_json_format);
  204. }
  205. std::string ProtoFileParser::GetFormattedStringFromMethod(
  206. const std::string& method, const std::string& serialized_proto,
  207. bool is_request, bool is_json_format) {
  208. has_error_ = false;
  209. std::string message_type_name = GetMessageTypeFromMethod(method, is_request);
  210. if (has_error_) {
  211. return "";
  212. }
  213. return GetFormattedStringFromMessageType(message_type_name, serialized_proto,
  214. is_json_format);
  215. }
  216. std::string ProtoFileParser::GetSerializedProtoFromMessageType(
  217. const std::string& message_type_name, const std::string& formatted_proto,
  218. bool is_json_format) {
  219. has_error_ = false;
  220. std::string serialized;
  221. const protobuf::Descriptor* desc =
  222. desc_pool_->FindMessageTypeByName(message_type_name);
  223. if (!desc) {
  224. LogError("Message type not found");
  225. return "";
  226. }
  227. std::unique_ptr<grpc::protobuf::Message> msg(
  228. dynamic_factory_->GetPrototype(desc)->New());
  229. bool ok;
  230. if (is_json_format) {
  231. ok = grpc::protobuf::json::JsonStringToMessage(formatted_proto, msg.get())
  232. .ok();
  233. if (!ok) {
  234. LogError("Failed to convert json format to proto.");
  235. return "";
  236. }
  237. } else {
  238. ok = protobuf::TextFormat::ParseFromString(formatted_proto, msg.get());
  239. if (!ok) {
  240. LogError("Failed to convert text format to proto.");
  241. return "";
  242. }
  243. }
  244. ok = msg->SerializeToString(&serialized);
  245. if (!ok) {
  246. LogError("Failed to serialize proto.");
  247. return "";
  248. }
  249. return serialized;
  250. }
  251. std::string ProtoFileParser::GetFormattedStringFromMessageType(
  252. const std::string& message_type_name, const std::string& serialized_proto,
  253. bool is_json_format) {
  254. has_error_ = false;
  255. const protobuf::Descriptor* desc =
  256. desc_pool_->FindMessageTypeByName(message_type_name);
  257. if (!desc) {
  258. LogError("Message type not found");
  259. return "";
  260. }
  261. std::unique_ptr<grpc::protobuf::Message> msg(
  262. dynamic_factory_->GetPrototype(desc)->New());
  263. if (!msg->ParseFromString(serialized_proto)) {
  264. LogError("Failed to deserialize proto.");
  265. return "";
  266. }
  267. std::string formatted_string;
  268. if (is_json_format) {
  269. grpc::protobuf::json::JsonPrintOptions jsonPrintOptions;
  270. jsonPrintOptions.add_whitespace = true;
  271. if (!grpc::protobuf::json::MessageToJsonString(*msg, &formatted_string,
  272. jsonPrintOptions)
  273. .ok()) {
  274. LogError("Failed to print proto message to json format");
  275. return "";
  276. }
  277. } else {
  278. if (!protobuf::TextFormat::PrintToString(*msg, &formatted_string)) {
  279. LogError("Failed to print proto message to text format");
  280. return "";
  281. }
  282. }
  283. return formatted_string;
  284. }
  285. void ProtoFileParser::LogError(const std::string& error_msg) {
  286. if (!error_msg.empty()) {
  287. std::cerr << error_msg << std::endl;
  288. }
  289. has_error_ = true;
  290. }
  291. } // namespace testing
  292. } // namespace grpc