interceptor_common.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. /*
  2. *
  3. * Copyright 2018 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
  19. #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H
  20. // IWYU pragma: private
  21. #include <array>
  22. #include <functional>
  23. #include <grpc/impl/codegen/grpc_types.h>
  24. #include <grpcpp/impl/codegen/call.h>
  25. #include <grpcpp/impl/codegen/call_op_set_interface.h>
  26. #include <grpcpp/impl/codegen/client_interceptor.h>
  27. #include <grpcpp/impl/codegen/intercepted_channel.h>
  28. #include <grpcpp/impl/codegen/server_interceptor.h>
  29. namespace grpc {
  30. namespace internal {
  31. class InterceptorBatchMethodsImpl
  32. : public experimental::InterceptorBatchMethods {
  33. public:
  34. InterceptorBatchMethodsImpl() {
  35. for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
  36. i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
  37. i = static_cast<experimental::InterceptionHookPoints>(
  38. static_cast<size_t>(i) + 1)) {
  39. hooks_[static_cast<size_t>(i)] = false;
  40. }
  41. }
  42. ~InterceptorBatchMethodsImpl() override {}
  43. bool QueryInterceptionHookPoint(
  44. experimental::InterceptionHookPoints type) override {
  45. return hooks_[static_cast<size_t>(type)];
  46. }
  47. void Proceed() override {
  48. if (call_->client_rpc_info() != nullptr) {
  49. return ProceedClient();
  50. }
  51. GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr);
  52. ProceedServer();
  53. }
  54. void Hijack() override {
  55. // Only the client can hijack when sending down initial metadata
  56. GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr &&
  57. call_->client_rpc_info() != nullptr);
  58. // It is illegal to call Hijack twice
  59. GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_);
  60. auto* rpc_info = call_->client_rpc_info();
  61. rpc_info->hijacked_ = true;
  62. rpc_info->hijacked_interceptor_ = current_interceptor_index_;
  63. ClearHookPoints();
  64. ops_->SetHijackingState();
  65. ran_hijacking_interceptor_ = true;
  66. rpc_info->RunInterceptor(this, current_interceptor_index_);
  67. }
  68. void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) {
  69. hooks_[static_cast<size_t>(type)] = true;
  70. }
  71. ByteBuffer* GetSerializedSendMessage() override {
  72. GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
  73. if (*orig_send_message_ != nullptr) {
  74. GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
  75. *orig_send_message_ = nullptr;
  76. }
  77. return send_message_;
  78. }
  79. const void* GetSendMessage() override {
  80. GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
  81. return *orig_send_message_;
  82. }
  83. void ModifySendMessage(const void* message) override {
  84. GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
  85. *orig_send_message_ = message;
  86. }
  87. bool GetSendMessageStatus() override { return !*fail_send_message_; }
  88. std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
  89. return send_initial_metadata_;
  90. }
  91. Status GetSendStatus() override {
  92. return Status(static_cast<StatusCode>(*code_), *error_message_,
  93. *error_details_);
  94. }
  95. void ModifySendStatus(const Status& status) override {
  96. *code_ = static_cast<grpc_status_code>(status.error_code());
  97. *error_details_ = status.error_details();
  98. *error_message_ = status.error_message();
  99. }
  100. std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
  101. return send_trailing_metadata_;
  102. }
  103. void* GetRecvMessage() override { return recv_message_; }
  104. std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
  105. override {
  106. return recv_initial_metadata_->map();
  107. }
  108. Status* GetRecvStatus() override { return recv_status_; }
  109. void FailHijackedSendMessage() override {
  110. GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
  111. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
  112. *fail_send_message_ = true;
  113. }
  114. std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
  115. override {
  116. return recv_trailing_metadata_->map();
  117. }
  118. void SetSendMessage(ByteBuffer* buf, const void** msg,
  119. bool* fail_send_message,
  120. std::function<Status(const void*)> serializer) {
  121. send_message_ = buf;
  122. orig_send_message_ = msg;
  123. fail_send_message_ = fail_send_message;
  124. serializer_ = serializer;
  125. }
  126. void SetSendInitialMetadata(
  127. std::multimap<std::string, std::string>* metadata) {
  128. send_initial_metadata_ = metadata;
  129. }
  130. void SetSendStatus(grpc_status_code* code, std::string* error_details,
  131. std::string* error_message) {
  132. code_ = code;
  133. error_details_ = error_details;
  134. error_message_ = error_message;
  135. }
  136. void SetSendTrailingMetadata(
  137. std::multimap<std::string, std::string>* metadata) {
  138. send_trailing_metadata_ = metadata;
  139. }
  140. void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) {
  141. recv_message_ = message;
  142. hijacked_recv_message_failed_ = hijacked_recv_message_failed;
  143. }
  144. void SetRecvInitialMetadata(MetadataMap* map) {
  145. recv_initial_metadata_ = map;
  146. }
  147. void SetRecvStatus(Status* status) { recv_status_ = status; }
  148. void SetRecvTrailingMetadata(MetadataMap* map) {
  149. recv_trailing_metadata_ = map;
  150. }
  151. std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
  152. auto* info = call_->client_rpc_info();
  153. if (info == nullptr) {
  154. return std::unique_ptr<ChannelInterface>(nullptr);
  155. }
  156. // The intercepted channel starts from the interceptor just after the
  157. // current interceptor
  158. return std::unique_ptr<ChannelInterface>(new InterceptedChannel(
  159. info->channel(), current_interceptor_index_ + 1));
  160. }
  161. void FailHijackedRecvMessage() override {
  162. GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
  163. experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
  164. *hijacked_recv_message_failed_ = true;
  165. }
  166. // Clears all state
  167. void ClearState() {
  168. reverse_ = false;
  169. ran_hijacking_interceptor_ = false;
  170. ClearHookPoints();
  171. }
  172. // Prepares for Post_recv operations
  173. void SetReverse() {
  174. reverse_ = true;
  175. ran_hijacking_interceptor_ = false;
  176. ClearHookPoints();
  177. }
  178. // This needs to be set before interceptors are run
  179. void SetCall(Call* call) { call_ = call; }
  180. // This needs to be set before interceptors are run using RunInterceptors().
  181. // Alternatively, RunInterceptors(std::function<void(void)> f) can be used.
  182. void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; }
  183. // SetCall should have been called before this.
  184. // Returns true if the interceptors list is empty
  185. bool InterceptorsListEmpty() {
  186. auto* client_rpc_info = call_->client_rpc_info();
  187. if (client_rpc_info != nullptr) {
  188. return client_rpc_info->interceptors_.empty();
  189. }
  190. auto* server_rpc_info = call_->server_rpc_info();
  191. return server_rpc_info == nullptr || server_rpc_info->interceptors_.empty();
  192. }
  193. // This should be used only by subclasses of CallOpSetInterface. SetCall and
  194. // SetCallOpSetInterface should have been called before this. After all the
  195. // interceptors are done running, either ContinueFillOpsAfterInterception or
  196. // ContinueFinalizeOpsAfterInterception will be called. Note that neither of
  197. // them is invoked if there were no interceptors registered.
  198. bool RunInterceptors() {
  199. GPR_CODEGEN_ASSERT(ops_);
  200. auto* client_rpc_info = call_->client_rpc_info();
  201. if (client_rpc_info != nullptr) {
  202. if (client_rpc_info->interceptors_.empty()) {
  203. return true;
  204. } else {
  205. RunClientInterceptors();
  206. return false;
  207. }
  208. }
  209. auto* server_rpc_info = call_->server_rpc_info();
  210. if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
  211. return true;
  212. }
  213. RunServerInterceptors();
  214. return false;
  215. }
  216. // Returns true if no interceptors are run. Returns false otherwise if there
  217. // are interceptors registered. After the interceptors are done running \a f
  218. // will be invoked. This is to be used only by BaseAsyncRequest and
  219. // SyncRequest.
  220. bool RunInterceptors(std::function<void(void)> f) {
  221. // This is used only by the server for initial call request
  222. GPR_CODEGEN_ASSERT(reverse_ == true);
  223. GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
  224. auto* server_rpc_info = call_->server_rpc_info();
  225. if (server_rpc_info == nullptr || server_rpc_info->interceptors_.empty()) {
  226. return true;
  227. }
  228. callback_ = std::move(f);
  229. RunServerInterceptors();
  230. return false;
  231. }
  232. private:
  233. void RunClientInterceptors() {
  234. auto* rpc_info = call_->client_rpc_info();
  235. if (!reverse_) {
  236. current_interceptor_index_ = 0;
  237. } else {
  238. if (rpc_info->hijacked_) {
  239. current_interceptor_index_ = rpc_info->hijacked_interceptor_;
  240. } else {
  241. current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
  242. }
  243. }
  244. rpc_info->RunInterceptor(this, current_interceptor_index_);
  245. }
  246. void RunServerInterceptors() {
  247. auto* rpc_info = call_->server_rpc_info();
  248. if (!reverse_) {
  249. current_interceptor_index_ = 0;
  250. } else {
  251. current_interceptor_index_ = rpc_info->interceptors_.size() - 1;
  252. }
  253. rpc_info->RunInterceptor(this, current_interceptor_index_);
  254. }
  255. void ProceedClient() {
  256. auto* rpc_info = call_->client_rpc_info();
  257. if (rpc_info->hijacked_ && !reverse_ &&
  258. current_interceptor_index_ == rpc_info->hijacked_interceptor_ &&
  259. !ran_hijacking_interceptor_) {
  260. // We now need to provide hijacked recv ops to this interceptor
  261. ClearHookPoints();
  262. ops_->SetHijackingState();
  263. ran_hijacking_interceptor_ = true;
  264. rpc_info->RunInterceptor(this, current_interceptor_index_);
  265. return;
  266. }
  267. if (!reverse_) {
  268. current_interceptor_index_++;
  269. // We are going down the stack of interceptors
  270. if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
  271. if (rpc_info->hijacked_ &&
  272. current_interceptor_index_ > rpc_info->hijacked_interceptor_) {
  273. // This is a hijacked RPC and we are done with hijacking
  274. ops_->ContinueFillOpsAfterInterception();
  275. } else {
  276. rpc_info->RunInterceptor(this, current_interceptor_index_);
  277. }
  278. } else {
  279. // we are done running all the interceptors without any hijacking
  280. ops_->ContinueFillOpsAfterInterception();
  281. }
  282. } else {
  283. // We are going up the stack of interceptors
  284. if (current_interceptor_index_ > 0) {
  285. // Continue running interceptors
  286. current_interceptor_index_--;
  287. rpc_info->RunInterceptor(this, current_interceptor_index_);
  288. } else {
  289. // we are done running all the interceptors without any hijacking
  290. ops_->ContinueFinalizeResultAfterInterception();
  291. }
  292. }
  293. }
  294. void ProceedServer() {
  295. auto* rpc_info = call_->server_rpc_info();
  296. if (!reverse_) {
  297. current_interceptor_index_++;
  298. if (current_interceptor_index_ < rpc_info->interceptors_.size()) {
  299. return rpc_info->RunInterceptor(this, current_interceptor_index_);
  300. } else if (ops_) {
  301. return ops_->ContinueFillOpsAfterInterception();
  302. }
  303. } else {
  304. // We are going up the stack of interceptors
  305. if (current_interceptor_index_ > 0) {
  306. // Continue running interceptors
  307. current_interceptor_index_--;
  308. return rpc_info->RunInterceptor(this, current_interceptor_index_);
  309. } else if (ops_) {
  310. return ops_->ContinueFinalizeResultAfterInterception();
  311. }
  312. }
  313. GPR_CODEGEN_ASSERT(callback_);
  314. callback_();
  315. }
  316. void ClearHookPoints() {
  317. for (auto i = static_cast<experimental::InterceptionHookPoints>(0);
  318. i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS;
  319. i = static_cast<experimental::InterceptionHookPoints>(
  320. static_cast<size_t>(i) + 1)) {
  321. hooks_[static_cast<size_t>(i)] = false;
  322. }
  323. }
  324. std::array<bool,
  325. static_cast<size_t>(
  326. experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)>
  327. hooks_;
  328. size_t current_interceptor_index_ = 0; // Current iterator
  329. bool reverse_ = false;
  330. bool ran_hijacking_interceptor_ = false;
  331. Call* call_ = nullptr; // The Call object is present along with CallOpSet
  332. // object/callback
  333. CallOpSetInterface* ops_ = nullptr;
  334. std::function<void(void)> callback_;
  335. ByteBuffer* send_message_ = nullptr;
  336. bool* fail_send_message_ = nullptr;
  337. const void** orig_send_message_ = nullptr;
  338. std::function<Status(const void*)> serializer_;
  339. std::multimap<std::string, std::string>* send_initial_metadata_;
  340. grpc_status_code* code_ = nullptr;
  341. std::string* error_details_ = nullptr;
  342. std::string* error_message_ = nullptr;
  343. std::multimap<std::string, std::string>* send_trailing_metadata_ = nullptr;
  344. void* recv_message_ = nullptr;
  345. bool* hijacked_recv_message_failed_ = nullptr;
  346. MetadataMap* recv_initial_metadata_ = nullptr;
  347. Status* recv_status_ = nullptr;
  348. MetadataMap* recv_trailing_metadata_ = nullptr;
  349. };
  350. // A special implementation of InterceptorBatchMethods to send a Cancel
  351. // notification down the interceptor stack
  352. class CancelInterceptorBatchMethods
  353. : public experimental::InterceptorBatchMethods {
  354. public:
  355. bool QueryInterceptionHookPoint(
  356. experimental::InterceptionHookPoints type) override {
  357. return type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL;
  358. }
  359. void Proceed() override {
  360. // This is a no-op. For actual continuation of the RPC simply needs to
  361. // return from the Intercept method
  362. }
  363. void Hijack() override {
  364. // Only the client can hijack when sending down initial metadata
  365. GPR_CODEGEN_ASSERT(false &&
  366. "It is illegal to call Hijack on a method which has a "
  367. "Cancel notification");
  368. }
  369. ByteBuffer* GetSerializedSendMessage() override {
  370. GPR_CODEGEN_ASSERT(false &&
  371. "It is illegal to call GetSendMessage on a method which "
  372. "has a Cancel notification");
  373. return nullptr;
  374. }
  375. bool GetSendMessageStatus() override {
  376. GPR_CODEGEN_ASSERT(
  377. false &&
  378. "It is illegal to call GetSendMessageStatus on a method which "
  379. "has a Cancel notification");
  380. return false;
  381. }
  382. const void* GetSendMessage() override {
  383. GPR_CODEGEN_ASSERT(
  384. false &&
  385. "It is illegal to call GetOriginalSendMessage on a method which "
  386. "has a Cancel notification");
  387. return nullptr;
  388. }
  389. void ModifySendMessage(const void* /*message*/) override {
  390. GPR_CODEGEN_ASSERT(
  391. false &&
  392. "It is illegal to call ModifySendMessage on a method which "
  393. "has a Cancel notification");
  394. }
  395. std::multimap<std::string, std::string>* GetSendInitialMetadata() override {
  396. GPR_CODEGEN_ASSERT(false &&
  397. "It is illegal to call GetSendInitialMetadata on a "
  398. "method which has a Cancel notification");
  399. return nullptr;
  400. }
  401. Status GetSendStatus() override {
  402. GPR_CODEGEN_ASSERT(false &&
  403. "It is illegal to call GetSendStatus on a method which "
  404. "has a Cancel notification");
  405. return Status();
  406. }
  407. void ModifySendStatus(const Status& /*status*/) override {
  408. GPR_CODEGEN_ASSERT(false &&
  409. "It is illegal to call ModifySendStatus on a method "
  410. "which has a Cancel notification");
  411. }
  412. std::multimap<std::string, std::string>* GetSendTrailingMetadata() override {
  413. GPR_CODEGEN_ASSERT(false &&
  414. "It is illegal to call GetSendTrailingMetadata on a "
  415. "method which has a Cancel notification");
  416. return nullptr;
  417. }
  418. void* GetRecvMessage() override {
  419. GPR_CODEGEN_ASSERT(false &&
  420. "It is illegal to call GetRecvMessage on a method which "
  421. "has a Cancel notification");
  422. return nullptr;
  423. }
  424. std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata()
  425. override {
  426. GPR_CODEGEN_ASSERT(false &&
  427. "It is illegal to call GetRecvInitialMetadata on a "
  428. "method which has a Cancel notification");
  429. return nullptr;
  430. }
  431. Status* GetRecvStatus() override {
  432. GPR_CODEGEN_ASSERT(false &&
  433. "It is illegal to call GetRecvStatus on a method which "
  434. "has a Cancel notification");
  435. return nullptr;
  436. }
  437. std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
  438. override {
  439. GPR_CODEGEN_ASSERT(false &&
  440. "It is illegal to call GetRecvTrailingMetadata on a "
  441. "method which has a Cancel notification");
  442. return nullptr;
  443. }
  444. std::unique_ptr<ChannelInterface> GetInterceptedChannel() override {
  445. GPR_CODEGEN_ASSERT(false &&
  446. "It is illegal to call GetInterceptedChannel on a "
  447. "method which has a Cancel notification");
  448. return std::unique_ptr<ChannelInterface>(nullptr);
  449. }
  450. void FailHijackedRecvMessage() override {
  451. GPR_CODEGEN_ASSERT(false &&
  452. "It is illegal to call FailHijackedRecvMessage on a "
  453. "method which has a Cancel notification");
  454. }
  455. void FailHijackedSendMessage() override {
  456. GPR_CODEGEN_ASSERT(false &&
  457. "It is illegal to call FailHijackedSendMessage on a "
  458. "method which has a Cancel notification");
  459. }
  460. };
  461. } // namespace internal
  462. } // namespace grpc
  463. #endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H