// File: client_interceptors_end2end_test.cc (45 KB, 1244 lines).
  1. /*
  2. *
  3. * Copyright 2018 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. #include <memory>
  19. #include <vector>
  20. #include <gtest/gtest.h>
  21. #include "absl/memory/memory.h"
  22. #include <grpcpp/channel.h>
  23. #include <grpcpp/client_context.h>
  24. #include <grpcpp/create_channel.h>
  25. #include <grpcpp/create_channel_posix.h>
  26. #include <grpcpp/generic/generic_stub.h>
  27. #include <grpcpp/impl/codegen/proto_utils.h>
  28. #include <grpcpp/server.h>
  29. #include <grpcpp/server_builder.h>
  30. #include <grpcpp/server_context.h>
  31. #include <grpcpp/server_posix.h>
  32. #include <grpcpp/support/client_interceptor.h>
  33. #include "src/core/lib/iomgr/port.h"
  34. #include "src/proto/grpc/testing/echo.grpc.pb.h"
  35. #include "test/core/util/port.h"
  36. #include "test/core/util/test_config.h"
  37. #include "test/cpp/end2end/interceptors_util.h"
  38. #include "test/cpp/end2end/test_service_impl.h"
  39. #include "test/cpp/util/byte_buffer_proto_helper.h"
  40. #include "test/cpp/util/string_ref_helper.h"
  41. #ifdef GRPC_POSIX_SOCKET
  42. #include <fcntl.h>
  43. #include "src/core/lib/iomgr/socket_utils_posix.h"
  44. #endif /* GRPC_POSIX_SOCKET */
  45. namespace grpc {
  46. namespace testing {
  47. namespace {
  // The kinds of RPC a parameterized test case can issue: each of the four
  // call shapes, through the synchronous stub API or through the async
  // completion-queue ("CQ") API.
  enum class RPCType {
    // Synchronous stub API.
    kSyncUnary, kSyncClientStreaming, kSyncServerStreaming, kSyncBidiStreaming,
    // Asynchronous completion-queue API.
    kAsyncCQUnary, kAsyncCQClientStreaming, kAsyncCQServerStreaming,
    kAsyncCQBidiStreaming
  };

  // How a parameterized test case reaches the server.
  enum class ChannelType {
    kHttpChannel,  // connect to a listening port on localhost
    kFdChannel     // use one end of a pre-connected socketpair (POSIX only)
  };
  62. /* Hijacks Echo RPC and fills in the expected values */
  63. class HijackingInterceptor : public experimental::Interceptor {
  64. public:
  65. explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
  66. info_ = info;
  67. // Make sure it is the right method
  68. EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
  69. EXPECT_EQ(info->suffix_for_stats(), nullptr);
  70. EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
  71. }
  72. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  73. bool hijack = false;
  74. if (methods->QueryInterceptionHookPoint(
  75. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  76. auto* map = methods->GetSendInitialMetadata();
  77. // Check that we can see the test metadata
  78. ASSERT_EQ(map->size(), static_cast<unsigned>(1));
  79. auto iterator = map->begin();
  80. EXPECT_EQ("testkey", iterator->first);
  81. EXPECT_EQ("testvalue", iterator->second);
  82. hijack = true;
  83. }
  84. if (methods->QueryInterceptionHookPoint(
  85. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  86. EchoRequest req;
  87. auto* buffer = methods->GetSerializedSendMessage();
  88. auto copied_buffer = *buffer;
  89. EXPECT_TRUE(
  90. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  91. .ok());
  92. EXPECT_EQ(req.message(), "Hello");
  93. }
  94. if (methods->QueryInterceptionHookPoint(
  95. experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
  96. // Got nothing to do here for now
  97. }
  98. if (methods->QueryInterceptionHookPoint(
  99. experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
  100. auto* map = methods->GetRecvInitialMetadata();
  101. // Got nothing better to do here for now
  102. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  103. }
  104. if (methods->QueryInterceptionHookPoint(
  105. experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
  106. EchoResponse* resp =
  107. static_cast<EchoResponse*>(methods->GetRecvMessage());
  108. // Check that we got the hijacked message, and re-insert the expected
  109. // message
  110. EXPECT_EQ(resp->message(), "Hello1");
  111. resp->set_message("Hello");
  112. }
  113. if (methods->QueryInterceptionHookPoint(
  114. experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
  115. auto* map = methods->GetRecvTrailingMetadata();
  116. bool found = false;
  117. // Check that we received the metadata as an echo
  118. for (const auto& pair : *map) {
  119. found = pair.first.starts_with("testkey") &&
  120. pair.second.starts_with("testvalue");
  121. if (found) break;
  122. }
  123. EXPECT_EQ(found, true);
  124. auto* status = methods->GetRecvStatus();
  125. EXPECT_EQ(status->ok(), true);
  126. }
  127. if (methods->QueryInterceptionHookPoint(
  128. experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
  129. auto* map = methods->GetRecvInitialMetadata();
  130. // Got nothing better to do here at the moment
  131. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  132. }
  133. if (methods->QueryInterceptionHookPoint(
  134. experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
  135. // Insert a different message than expected
  136. EchoResponse* resp =
  137. static_cast<EchoResponse*>(methods->GetRecvMessage());
  138. resp->set_message("Hello1");
  139. }
  140. if (methods->QueryInterceptionHookPoint(
  141. experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
  142. auto* map = methods->GetRecvTrailingMetadata();
  143. // insert the metadata that we want
  144. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  145. map->insert(std::make_pair("testkey", "testvalue"));
  146. auto* status = methods->GetRecvStatus();
  147. *status = Status(StatusCode::OK, "");
  148. }
  149. if (hijack) {
  150. methods->Hijack();
  151. } else {
  152. methods->Proceed();
  153. }
  154. }
  155. private:
  156. experimental::ClientRpcInfo* info_;
  157. };
  158. class HijackingInterceptorFactory
  159. : public experimental::ClientInterceptorFactoryInterface {
  160. public:
  161. experimental::Interceptor* CreateClientInterceptor(
  162. experimental::ClientRpcInfo* info) override {
  163. return new HijackingInterceptor(info);
  164. }
  165. };
  166. class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
  167. public:
  168. explicit HijackingInterceptorMakesAnotherCall(
  169. experimental::ClientRpcInfo* info) {
  170. info_ = info;
  171. // Make sure it is the right method
  172. EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
  173. EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
  174. }
  175. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  176. if (methods->QueryInterceptionHookPoint(
  177. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  178. auto* map = methods->GetSendInitialMetadata();
  179. // Check that we can see the test metadata
  180. ASSERT_EQ(map->size(), static_cast<unsigned>(1));
  181. auto iterator = map->begin();
  182. EXPECT_EQ("testkey", iterator->first);
  183. EXPECT_EQ("testvalue", iterator->second);
  184. // Make a copy of the map
  185. metadata_map_ = *map;
  186. }
  187. if (methods->QueryInterceptionHookPoint(
  188. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  189. EchoRequest req;
  190. auto* buffer = methods->GetSerializedSendMessage();
  191. auto copied_buffer = *buffer;
  192. EXPECT_TRUE(
  193. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  194. .ok());
  195. EXPECT_EQ(req.message(), "Hello");
  196. req_ = req;
  197. stub_ = grpc::testing::EchoTestService::NewStub(
  198. methods->GetInterceptedChannel());
  199. ctx_.AddMetadata(metadata_map_.begin()->first,
  200. metadata_map_.begin()->second);
  201. stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) {
  202. EXPECT_EQ(s.ok(), true);
  203. EXPECT_EQ(resp_.message(), "Hello");
  204. methods->Hijack();
  205. });
  206. // This is a Unary RPC and we have got nothing interesting to do in the
  207. // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
  208. // return here. (We do not want to call methods->Proceed(). When the new
  209. // RPC returns, we will call methods->Hijack() instead.)
  210. return;
  211. }
  212. if (methods->QueryInterceptionHookPoint(
  213. experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
  214. // Got nothing to do here for now
  215. }
  216. if (methods->QueryInterceptionHookPoint(
  217. experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
  218. auto* map = methods->GetRecvInitialMetadata();
  219. // Got nothing better to do here for now
  220. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  221. }
  222. if (methods->QueryInterceptionHookPoint(
  223. experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
  224. EchoResponse* resp =
  225. static_cast<EchoResponse*>(methods->GetRecvMessage());
  226. // Check that we got the hijacked message, and re-insert the expected
  227. // message
  228. EXPECT_EQ(resp->message(), "Hello");
  229. }
  230. if (methods->QueryInterceptionHookPoint(
  231. experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
  232. auto* map = methods->GetRecvTrailingMetadata();
  233. bool found = false;
  234. // Check that we received the metadata as an echo
  235. for (const auto& pair : *map) {
  236. found = pair.first.starts_with("testkey") &&
  237. pair.second.starts_with("testvalue");
  238. if (found) break;
  239. }
  240. EXPECT_EQ(found, true);
  241. auto* status = methods->GetRecvStatus();
  242. EXPECT_EQ(status->ok(), true);
  243. }
  244. if (methods->QueryInterceptionHookPoint(
  245. experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
  246. auto* map = methods->GetRecvInitialMetadata();
  247. // Got nothing better to do here at the moment
  248. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  249. }
  250. if (methods->QueryInterceptionHookPoint(
  251. experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
  252. // Insert a different message than expected
  253. EchoResponse* resp =
  254. static_cast<EchoResponse*>(methods->GetRecvMessage());
  255. resp->set_message(resp_.message());
  256. }
  257. if (methods->QueryInterceptionHookPoint(
  258. experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
  259. auto* map = methods->GetRecvTrailingMetadata();
  260. // insert the metadata that we want
  261. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  262. map->insert(std::make_pair("testkey", "testvalue"));
  263. auto* status = methods->GetRecvStatus();
  264. *status = Status(StatusCode::OK, "");
  265. }
  266. methods->Proceed();
  267. }
  268. private:
  269. experimental::ClientRpcInfo* info_;
  270. std::multimap<std::string, std::string> metadata_map_;
  271. ClientContext ctx_;
  272. EchoRequest req_;
  273. EchoResponse resp_;
  274. std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
  275. };
  276. class HijackingInterceptorMakesAnotherCallFactory
  277. : public experimental::ClientInterceptorFactoryInterface {
  278. public:
  279. experimental::Interceptor* CreateClientInterceptor(
  280. experimental::ClientRpcInfo* info) override {
  281. return new HijackingInterceptorMakesAnotherCall(info);
  282. }
  283. };
  284. class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
  285. public:
  286. explicit BidiStreamingRpcHijackingInterceptor(
  287. experimental::ClientRpcInfo* info) {
  288. info_ = info;
  289. EXPECT_EQ(info->suffix_for_stats(), nullptr);
  290. }
  291. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  292. bool hijack = false;
  293. if (methods->QueryInterceptionHookPoint(
  294. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  295. CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
  296. hijack = true;
  297. }
  298. if (methods->QueryInterceptionHookPoint(
  299. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  300. EchoRequest req;
  301. auto* buffer = methods->GetSerializedSendMessage();
  302. auto copied_buffer = *buffer;
  303. EXPECT_TRUE(
  304. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  305. .ok());
  306. EXPECT_EQ(req.message().find("Hello"), 0u);
  307. msg = req.message();
  308. }
  309. if (methods->QueryInterceptionHookPoint(
  310. experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
  311. // Got nothing to do here for now
  312. }
  313. if (methods->QueryInterceptionHookPoint(
  314. experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
  315. CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
  316. "testvalue");
  317. auto* status = methods->GetRecvStatus();
  318. EXPECT_EQ(status->ok(), true);
  319. }
  320. if (methods->QueryInterceptionHookPoint(
  321. experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
  322. EchoResponse* resp =
  323. static_cast<EchoResponse*>(methods->GetRecvMessage());
  324. resp->set_message(msg);
  325. }
  326. if (methods->QueryInterceptionHookPoint(
  327. experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
  328. EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
  329. ->message()
  330. .find("Hello"),
  331. 0u);
  332. }
  333. if (methods->QueryInterceptionHookPoint(
  334. experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
  335. auto* map = methods->GetRecvTrailingMetadata();
  336. // insert the metadata that we want
  337. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  338. map->insert(std::make_pair("testkey", "testvalue"));
  339. auto* status = methods->GetRecvStatus();
  340. *status = Status(StatusCode::OK, "");
  341. }
  342. if (hijack) {
  343. methods->Hijack();
  344. } else {
  345. methods->Proceed();
  346. }
  347. }
  348. private:
  349. experimental::ClientRpcInfo* info_;
  350. std::string msg;
  351. };
  352. class ClientStreamingRpcHijackingInterceptor
  353. : public experimental::Interceptor {
  354. public:
  355. explicit ClientStreamingRpcHijackingInterceptor(
  356. experimental::ClientRpcInfo* info) {
  357. info_ = info;
  358. EXPECT_EQ(
  359. strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()),
  360. 0);
  361. EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
  362. }
  363. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  364. bool hijack = false;
  365. if (methods->QueryInterceptionHookPoint(
  366. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  367. hijack = true;
  368. }
  369. if (methods->QueryInterceptionHookPoint(
  370. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  371. if (++count_ > 10) {
  372. methods->FailHijackedSendMessage();
  373. }
  374. }
  375. if (methods->QueryInterceptionHookPoint(
  376. experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
  377. EXPECT_FALSE(got_failed_send_);
  378. got_failed_send_ = !methods->GetSendMessageStatus();
  379. }
  380. if (methods->QueryInterceptionHookPoint(
  381. experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
  382. auto* status = methods->GetRecvStatus();
  383. *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
  384. }
  385. if (hijack) {
  386. methods->Hijack();
  387. } else {
  388. methods->Proceed();
  389. }
  390. }
  391. static bool GotFailedSend() { return got_failed_send_; }
  392. private:
  393. experimental::ClientRpcInfo* info_;
  394. int count_ = 0;
  395. static bool got_failed_send_;
  396. };
  397. bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
  398. class ClientStreamingRpcHijackingInterceptorFactory
  399. : public experimental::ClientInterceptorFactoryInterface {
  400. public:
  401. experimental::Interceptor* CreateClientInterceptor(
  402. experimental::ClientRpcInfo* info) override {
  403. return new ClientStreamingRpcHijackingInterceptor(info);
  404. }
  405. };
  406. class ServerStreamingRpcHijackingInterceptor
  407. : public experimental::Interceptor {
  408. public:
  409. explicit ServerStreamingRpcHijackingInterceptor(
  410. experimental::ClientRpcInfo* info) {
  411. info_ = info;
  412. got_failed_message_ = false;
  413. EXPECT_EQ(info->suffix_for_stats(), nullptr);
  414. }
  415. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  416. bool hijack = false;
  417. if (methods->QueryInterceptionHookPoint(
  418. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  419. auto* map = methods->GetSendInitialMetadata();
  420. // Check that we can see the test metadata
  421. ASSERT_EQ(map->size(), static_cast<unsigned>(1));
  422. auto iterator = map->begin();
  423. EXPECT_EQ("testkey", iterator->first);
  424. EXPECT_EQ("testvalue", iterator->second);
  425. hijack = true;
  426. }
  427. if (methods->QueryInterceptionHookPoint(
  428. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  429. EchoRequest req;
  430. auto* buffer = methods->GetSerializedSendMessage();
  431. auto copied_buffer = *buffer;
  432. EXPECT_TRUE(
  433. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  434. .ok());
  435. EXPECT_EQ(req.message(), "Hello");
  436. }
  437. if (methods->QueryInterceptionHookPoint(
  438. experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
  439. // Got nothing to do here for now
  440. }
  441. if (methods->QueryInterceptionHookPoint(
  442. experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
  443. auto* map = methods->GetRecvTrailingMetadata();
  444. bool found = false;
  445. // Check that we received the metadata as an echo
  446. for (const auto& pair : *map) {
  447. found = pair.first.starts_with("testkey") &&
  448. pair.second.starts_with("testvalue");
  449. if (found) break;
  450. }
  451. EXPECT_EQ(found, true);
  452. auto* status = methods->GetRecvStatus();
  453. EXPECT_EQ(status->ok(), true);
  454. }
  455. if (methods->QueryInterceptionHookPoint(
  456. experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
  457. if (++count_ > 10) {
  458. methods->FailHijackedRecvMessage();
  459. }
  460. EchoResponse* resp =
  461. static_cast<EchoResponse*>(methods->GetRecvMessage());
  462. resp->set_message("Hello");
  463. }
  464. if (methods->QueryInterceptionHookPoint(
  465. experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
  466. // Only the last message will be a failure
  467. EXPECT_FALSE(got_failed_message_);
  468. got_failed_message_ = methods->GetRecvMessage() == nullptr;
  469. }
  470. if (methods->QueryInterceptionHookPoint(
  471. experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
  472. auto* map = methods->GetRecvTrailingMetadata();
  473. // insert the metadata that we want
  474. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  475. map->insert(std::make_pair("testkey", "testvalue"));
  476. auto* status = methods->GetRecvStatus();
  477. *status = Status(StatusCode::OK, "");
  478. }
  479. if (hijack) {
  480. methods->Hijack();
  481. } else {
  482. methods->Proceed();
  483. }
  484. }
  485. static bool GotFailedMessage() { return got_failed_message_; }
  486. private:
  487. experimental::ClientRpcInfo* info_;
  488. static bool got_failed_message_;
  489. int count_ = 0;
  490. };
  491. bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
  492. class ServerStreamingRpcHijackingInterceptorFactory
  493. : public experimental::ClientInterceptorFactoryInterface {
  494. public:
  495. experimental::Interceptor* CreateClientInterceptor(
  496. experimental::ClientRpcInfo* info) override {
  497. return new ServerStreamingRpcHijackingInterceptor(info);
  498. }
  499. };
  500. class BidiStreamingRpcHijackingInterceptorFactory
  501. : public experimental::ClientInterceptorFactoryInterface {
  502. public:
  503. experimental::Interceptor* CreateClientInterceptor(
  504. experimental::ClientRpcInfo* info) override {
  505. return new BidiStreamingRpcHijackingInterceptor(info);
  506. }
  507. };
  508. // The logging interceptor is for testing purposes only. It is used to verify
  509. // that all the appropriate hook points are invoked for an RPC. The counts are
  510. // reset each time a new object of LoggingInterceptor is created, so only a
  511. // single RPC should be made on the channel before calling the Verify methods.
  512. class LoggingInterceptor : public experimental::Interceptor {
  513. public:
  514. explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
  515. pre_send_initial_metadata_ = false;
  516. pre_send_message_count_ = 0;
  517. pre_send_close_ = false;
  518. post_recv_initial_metadata_ = false;
  519. post_recv_message_count_ = 0;
  520. post_recv_status_ = false;
  521. }
  522. void Intercept(experimental::InterceptorBatchMethods* methods) override {
  523. if (methods->QueryInterceptionHookPoint(
  524. experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
  525. auto* map = methods->GetSendInitialMetadata();
  526. // Check that we can see the test metadata
  527. ASSERT_EQ(map->size(), static_cast<unsigned>(1));
  528. auto iterator = map->begin();
  529. EXPECT_EQ("testkey", iterator->first);
  530. EXPECT_EQ("testvalue", iterator->second);
  531. ASSERT_FALSE(pre_send_initial_metadata_);
  532. pre_send_initial_metadata_ = true;
  533. }
  534. if (methods->QueryInterceptionHookPoint(
  535. experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
  536. EchoRequest req;
  537. auto* send_msg = methods->GetSendMessage();
  538. if (send_msg == nullptr) {
  539. // We did not get the non-serialized form of the message. Get the
  540. // serialized form.
  541. auto* buffer = methods->GetSerializedSendMessage();
  542. auto copied_buffer = *buffer;
  543. EchoRequest req;
  544. EXPECT_TRUE(
  545. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  546. .ok());
  547. EXPECT_EQ(req.message(), "Hello");
  548. } else {
  549. EXPECT_EQ(
  550. static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
  551. 0u);
  552. }
  553. auto* buffer = methods->GetSerializedSendMessage();
  554. auto copied_buffer = *buffer;
  555. EXPECT_TRUE(
  556. SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
  557. .ok());
  558. EXPECT_TRUE(req.message().find("Hello") == 0u);
  559. pre_send_message_count_++;
  560. }
  561. if (methods->QueryInterceptionHookPoint(
  562. experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
  563. // Got nothing to do here for now
  564. pre_send_close_ = true;
  565. }
  566. if (methods->QueryInterceptionHookPoint(
  567. experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
  568. auto* map = methods->GetRecvInitialMetadata();
  569. // Got nothing better to do here for now
  570. EXPECT_EQ(map->size(), static_cast<unsigned>(0));
  571. post_recv_initial_metadata_ = true;
  572. }
  573. if (methods->QueryInterceptionHookPoint(
  574. experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
  575. EchoResponse* resp =
  576. static_cast<EchoResponse*>(methods->GetRecvMessage());
  577. if (resp != nullptr) {
  578. EXPECT_TRUE(resp->message().find("Hello") == 0u);
  579. post_recv_message_count_++;
  580. }
  581. }
  582. if (methods->QueryInterceptionHookPoint(
  583. experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
  584. auto* map = methods->GetRecvTrailingMetadata();
  585. bool found = false;
  586. // Check that we received the metadata as an echo
  587. for (const auto& pair : *map) {
  588. found = pair.first.starts_with("testkey") &&
  589. pair.second.starts_with("testvalue");
  590. if (found) break;
  591. }
  592. EXPECT_EQ(found, true);
  593. auto* status = methods->GetRecvStatus();
  594. EXPECT_EQ(status->ok(), true);
  595. post_recv_status_ = true;
  596. }
  597. methods->Proceed();
  598. }
  599. static void VerifyCall(RPCType type) {
  600. switch (type) {
  601. case RPCType::kSyncUnary:
  602. case RPCType::kAsyncCQUnary:
  603. VerifyUnaryCall();
  604. break;
  605. case RPCType::kSyncClientStreaming:
  606. case RPCType::kAsyncCQClientStreaming:
  607. VerifyClientStreamingCall();
  608. break;
  609. case RPCType::kSyncServerStreaming:
  610. case RPCType::kAsyncCQServerStreaming:
  611. VerifyServerStreamingCall();
  612. break;
  613. case RPCType::kSyncBidiStreaming:
  614. case RPCType::kAsyncCQBidiStreaming:
  615. VerifyBidiStreamingCall();
  616. break;
  617. }
  618. }
  619. static void VerifyCallCommon() {
  620. EXPECT_TRUE(pre_send_initial_metadata_);
  621. EXPECT_TRUE(pre_send_close_);
  622. EXPECT_TRUE(post_recv_initial_metadata_);
  623. EXPECT_TRUE(post_recv_status_);
  624. }
  625. static void VerifyUnaryCall() {
  626. VerifyCallCommon();
  627. EXPECT_EQ(pre_send_message_count_, 1);
  628. EXPECT_EQ(post_recv_message_count_, 1);
  629. }
  630. static void VerifyClientStreamingCall() {
  631. VerifyCallCommon();
  632. EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
  633. EXPECT_EQ(post_recv_message_count_, 1);
  634. }
  635. static void VerifyServerStreamingCall() {
  636. VerifyCallCommon();
  637. EXPECT_EQ(pre_send_message_count_, 1);
  638. EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
  639. }
  640. static void VerifyBidiStreamingCall() {
  641. VerifyCallCommon();
  642. EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
  643. EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
  644. }
  645. private:
  646. static bool pre_send_initial_metadata_;
  647. static int pre_send_message_count_;
  648. static bool pre_send_close_;
  649. static bool post_recv_initial_metadata_;
  650. static int post_recv_message_count_;
  651. static bool post_recv_status_;
  652. };
  653. bool LoggingInterceptor::pre_send_initial_metadata_;
  654. int LoggingInterceptor::pre_send_message_count_;
  655. bool LoggingInterceptor::pre_send_close_;
  656. bool LoggingInterceptor::post_recv_initial_metadata_;
  657. int LoggingInterceptor::post_recv_message_count_;
  658. bool LoggingInterceptor::post_recv_status_;
  659. class LoggingInterceptorFactory
  660. : public experimental::ClientInterceptorFactoryInterface {
  661. public:
  662. experimental::Interceptor* CreateClientInterceptor(
  663. experimental::ClientRpcInfo* info) override {
  664. return new LoggingInterceptor(info);
  665. }
  666. };
  667. class TestScenario {
  668. public:
  669. explicit TestScenario(const ChannelType& channel_type,
  670. const RPCType& rpc_type)
  671. : channel_type_(channel_type), rpc_type_(rpc_type) {}
  672. ChannelType channel_type() const { return channel_type_; }
  673. RPCType rpc_type() const { return rpc_type_; }
  674. private:
  675. const ChannelType channel_type_;
  676. const RPCType rpc_type_;
  677. };
  678. std::vector<TestScenario> CreateTestScenarios() {
  679. std::vector<TestScenario> scenarios;
  680. std::vector<RPCType> rpc_types;
  681. rpc_types.emplace_back(RPCType::kSyncUnary);
  682. rpc_types.emplace_back(RPCType::kSyncClientStreaming);
  683. rpc_types.emplace_back(RPCType::kSyncServerStreaming);
  684. rpc_types.emplace_back(RPCType::kSyncBidiStreaming);
  685. rpc_types.emplace_back(RPCType::kAsyncCQUnary);
  686. rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming);
  687. for (const auto& rpc_type : rpc_types) {
  688. scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type);
  689. // TODO(yashykt): Maybe add support for non-posix sockets too
  690. #ifdef GRPC_POSIX_SOCKET
  691. scenarios.emplace_back(ChannelType::kFdChannel, rpc_type);
  692. #endif /* GRPC_POSIX_SOCKET */
  693. }
  694. return scenarios;
  695. }
  696. class ParameterizedClientInterceptorsEnd2endTest
  697. : public ::testing::TestWithParam<TestScenario> {
  698. protected:
  699. ParameterizedClientInterceptorsEnd2endTest() {
  700. ServerBuilder builder;
  701. builder.RegisterService(&service_);
  702. if (GetParam().channel_type() == ChannelType::kHttpChannel) {
  703. int port = grpc_pick_unused_port_or_die();
  704. server_address_ = "localhost:" + std::to_string(port);
  705. builder.AddListeningPort(server_address_, InsecureServerCredentials());
  706. server_ = builder.BuildAndStart();
  707. }
  708. #ifdef GRPC_POSIX_SOCKET
  709. else if (GetParam().channel_type() == ChannelType::kFdChannel) {
  710. int flags;
  711. GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_) == 0);
  712. flags = fcntl(sv_[0], F_GETFL, 0);
  713. GPR_ASSERT(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK) == 0);
  714. flags = fcntl(sv_[1], F_GETFL, 0);
  715. GPR_ASSERT(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK) == 0);
  716. GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) ==
  717. GRPC_ERROR_NONE);
  718. GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) ==
  719. GRPC_ERROR_NONE);
  720. server_ = builder.BuildAndStart();
  721. AddInsecureChannelFromFd(server_.get(), sv_[1]);
  722. }
  723. #endif /* GRPC_POSIX_SOCKET */
  724. }
  725. ~ParameterizedClientInterceptorsEnd2endTest() override {
  726. server_->Shutdown();
  727. }
  728. std::shared_ptr<grpc::Channel> CreateClientChannel(
  729. std::vector<std::unique_ptr<
  730. grpc::experimental::ClientInterceptorFactoryInterface>>
  731. creators) {
  732. if (GetParam().channel_type() == ChannelType::kHttpChannel) {
  733. return experimental::CreateCustomChannelWithInterceptors(
  734. server_address_, InsecureChannelCredentials(), ChannelArguments(),
  735. std::move(creators));
  736. }
  737. #ifdef GRPC_POSIX_SOCKET
  738. else if (GetParam().channel_type() == ChannelType::kFdChannel) {
  739. return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd(
  740. "", sv_[0], ChannelArguments(), std::move(creators));
  741. }
  742. #endif /* GRPC_POSIX_SOCKET */
  743. return nullptr;
  744. }
  745. void SendRPC(const std::shared_ptr<Channel>& channel) {
  746. switch (GetParam().rpc_type()) {
  747. case RPCType::kSyncUnary:
  748. MakeCall(channel);
  749. break;
  750. case RPCType::kSyncClientStreaming:
  751. MakeClientStreamingCall(channel);
  752. break;
  753. case RPCType::kSyncServerStreaming:
  754. MakeServerStreamingCall(channel);
  755. break;
  756. case RPCType::kSyncBidiStreaming:
  757. MakeBidiStreamingCall(channel);
  758. break;
  759. case RPCType::kAsyncCQUnary:
  760. MakeAsyncCQCall(channel);
  761. break;
  762. case RPCType::kAsyncCQClientStreaming:
  763. // TODO(yashykt) : Fill this out
  764. break;
  765. case RPCType::kAsyncCQServerStreaming:
  766. MakeAsyncCQServerStreamingCall(channel);
  767. break;
  768. case RPCType::kAsyncCQBidiStreaming:
  769. // TODO(yashykt) : Fill this out
  770. break;
  771. }
  772. }
  773. std::string server_address_;
  774. int sv_[2];
  775. EchoTestServiceStreamingImpl service_;
  776. std::unique_ptr<Server> server_;
  777. };
  778. TEST_P(ParameterizedClientInterceptorsEnd2endTest,
  779. ClientInterceptorLoggingTest) {
  780. ChannelArguments args;
  781. PhonyInterceptor::Reset();
  782. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  783. creators;
  784. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  785. // Add 20 phony interceptors
  786. for (auto i = 0; i < 20; i++) {
  787. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  788. }
  789. auto channel = CreateClientChannel(std::move(creators));
  790. SendRPC(channel);
  791. LoggingInterceptor::VerifyCall(GetParam().rpc_type());
  792. // Make sure all 20 phony interceptors were run
  793. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  794. }
  795. INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
  796. ParameterizedClientInterceptorsEnd2endTest,
  797. ::testing::ValuesIn(CreateTestScenarios()));
  798. class ClientInterceptorsEnd2endTest
  799. : public ::testing::TestWithParam<TestScenario> {
  800. protected:
  801. ClientInterceptorsEnd2endTest() {
  802. int port = grpc_pick_unused_port_or_die();
  803. ServerBuilder builder;
  804. server_address_ = "localhost:" + std::to_string(port);
  805. builder.AddListeningPort(server_address_, InsecureServerCredentials());
  806. builder.RegisterService(&service_);
  807. server_ = builder.BuildAndStart();
  808. }
  809. ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); }
  810. std::string server_address_;
  811. TestServiceImpl service_;
  812. std::unique_ptr<Server> server_;
  813. };
  814. TEST_F(ClientInterceptorsEnd2endTest,
  815. LameChannelClientInterceptorHijackingTest) {
  816. ChannelArguments args;
  817. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  818. creators;
  819. creators.push_back(absl::make_unique<HijackingInterceptorFactory>());
  820. auto channel = experimental::CreateCustomChannelWithInterceptors(
  821. server_address_, nullptr, args, std::move(creators));
  822. MakeCall(channel);
  823. }
  824. TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  825. ChannelArguments args;
  826. PhonyInterceptor::Reset();
  827. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  828. creators;
  829. // Add 20 phony interceptors before hijacking interceptor
  830. creators.reserve(20);
  831. for (auto i = 0; i < 20; i++) {
  832. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  833. }
  834. creators.push_back(absl::make_unique<HijackingInterceptorFactory>());
  835. // Add 20 phony interceptors after hijacking interceptor
  836. for (auto i = 0; i < 20; i++) {
  837. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  838. }
  839. auto channel = experimental::CreateCustomChannelWithInterceptors(
  840. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  841. MakeCall(channel);
  842. // Make sure only 20 phony interceptors were run
  843. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  844. }
  845. TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  846. ChannelArguments args;
  847. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  848. creators;
  849. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  850. creators.push_back(absl::make_unique<HijackingInterceptorFactory>());
  851. auto channel = experimental::CreateCustomChannelWithInterceptors(
  852. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  853. MakeCall(channel);
  854. LoggingInterceptor::VerifyUnaryCall();
  855. }
  856. TEST_F(ClientInterceptorsEnd2endTest,
  857. ClientInterceptorHijackingMakesAnotherCallTest) {
  858. ChannelArguments args;
  859. PhonyInterceptor::Reset();
  860. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  861. creators;
  862. // Add 5 phony interceptors before hijacking interceptor
  863. creators.reserve(5);
  864. for (auto i = 0; i < 5; i++) {
  865. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  866. }
  867. creators.push_back(
  868. std::unique_ptr<experimental::ClientInterceptorFactoryInterface>(
  869. new HijackingInterceptorMakesAnotherCallFactory()));
  870. // Add 7 phony interceptors after hijacking interceptor
  871. for (auto i = 0; i < 7; i++) {
  872. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  873. }
  874. auto channel = server_->experimental().InProcessChannelWithInterceptors(
  875. args, std::move(creators));
  876. MakeCall(channel, StubOptions("TestSuffixForStats"));
  877. // Make sure all interceptors were run once, since the hijacking interceptor
  878. // makes an RPC on the intercepted channel
  879. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 12);
  880. }
  881. class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
  882. protected:
  883. ClientInterceptorsCallbackEnd2endTest() {
  884. int port = grpc_pick_unused_port_or_die();
  885. ServerBuilder builder;
  886. server_address_ = "localhost:" + std::to_string(port);
  887. builder.AddListeningPort(server_address_, InsecureServerCredentials());
  888. builder.RegisterService(&service_);
  889. server_ = builder.BuildAndStart();
  890. }
  891. ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); }
  892. std::string server_address_;
  893. TestServiceImpl service_;
  894. std::unique_ptr<Server> server_;
  895. };
  896. TEST_F(ClientInterceptorsCallbackEnd2endTest,
  897. ClientInterceptorLoggingTestWithCallback) {
  898. ChannelArguments args;
  899. PhonyInterceptor::Reset();
  900. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  901. creators;
  902. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  903. // Add 20 phony interceptors
  904. for (auto i = 0; i < 20; i++) {
  905. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  906. }
  907. auto channel = server_->experimental().InProcessChannelWithInterceptors(
  908. args, std::move(creators));
  909. MakeCallbackCall(channel);
  910. LoggingInterceptor::VerifyUnaryCall();
  911. // Make sure all 20 phony interceptors were run
  912. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  913. }
  914. TEST_F(ClientInterceptorsCallbackEnd2endTest,
  915. ClientInterceptorFactoryAllowsNullptrReturn) {
  916. ChannelArguments args;
  917. PhonyInterceptor::Reset();
  918. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  919. creators;
  920. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  921. // Add 20 phony interceptors and 20 null interceptors
  922. for (auto i = 0; i < 20; i++) {
  923. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  924. creators.push_back(absl::make_unique<NullInterceptorFactory>());
  925. }
  926. auto channel = server_->experimental().InProcessChannelWithInterceptors(
  927. args, std::move(creators));
  928. MakeCallbackCall(channel);
  929. LoggingInterceptor::VerifyUnaryCall();
  930. // Make sure all 20 phony interceptors were run
  931. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  932. }
  933. class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
  934. protected:
  935. ClientInterceptorsStreamingEnd2endTest() {
  936. int port = grpc_pick_unused_port_or_die();
  937. ServerBuilder builder;
  938. server_address_ = "localhost:" + std::to_string(port);
  939. builder.AddListeningPort(server_address_, InsecureServerCredentials());
  940. builder.RegisterService(&service_);
  941. server_ = builder.BuildAndStart();
  942. }
  943. ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); }
  944. std::string server_address_;
  945. EchoTestServiceStreamingImpl service_;
  946. std::unique_ptr<Server> server_;
  947. };
  948. TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  949. ChannelArguments args;
  950. PhonyInterceptor::Reset();
  951. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  952. creators;
  953. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  954. // Add 20 phony interceptors
  955. for (auto i = 0; i < 20; i++) {
  956. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  957. }
  958. auto channel = experimental::CreateCustomChannelWithInterceptors(
  959. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  960. MakeClientStreamingCall(channel);
  961. LoggingInterceptor::VerifyClientStreamingCall();
  962. // Make sure all 20 phony interceptors were run
  963. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  964. }
  965. TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  966. ChannelArguments args;
  967. PhonyInterceptor::Reset();
  968. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  969. creators;
  970. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  971. // Add 20 phony interceptors
  972. for (auto i = 0; i < 20; i++) {
  973. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  974. }
  975. auto channel = experimental::CreateCustomChannelWithInterceptors(
  976. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  977. MakeServerStreamingCall(channel);
  978. LoggingInterceptor::VerifyServerStreamingCall();
  979. // Make sure all 20 phony interceptors were run
  980. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  981. }
  982. TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  983. ChannelArguments args;
  984. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  985. creators;
  986. creators.push_back(
  987. absl::make_unique<ClientStreamingRpcHijackingInterceptorFactory>());
  988. auto channel = experimental::CreateCustomChannelWithInterceptors(
  989. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  990. auto stub = grpc::testing::EchoTestService::NewStub(
  991. channel, StubOptions("TestSuffixForStats"));
  992. ClientContext ctx;
  993. EchoRequest req;
  994. EchoResponse resp;
  995. req.mutable_param()->set_echo_metadata(true);
  996. req.set_message("Hello");
  997. string expected_resp = "";
  998. auto writer = stub->RequestStream(&ctx, &resp);
  999. for (int i = 0; i < 10; i++) {
  1000. EXPECT_TRUE(writer->Write(req));
  1001. expected_resp += "Hello";
  1002. }
  1003. // The interceptor will reject the 11th message
  1004. writer->Write(req);
  1005. Status s = writer->Finish();
  1006. EXPECT_EQ(s.ok(), false);
  1007. EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
  1008. }
  1009. TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  1010. ChannelArguments args;
  1011. PhonyInterceptor::Reset();
  1012. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1013. creators;
  1014. creators.push_back(
  1015. absl::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
  1016. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1017. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1018. MakeServerStreamingCall(channel);
  1019. EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
  1020. }
  1021. TEST_F(ClientInterceptorsStreamingEnd2endTest,
  1022. AsyncCQServerStreamingHijackingTest) {
  1023. ChannelArguments args;
  1024. PhonyInterceptor::Reset();
  1025. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1026. creators;
  1027. creators.push_back(
  1028. absl::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());
  1029. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1030. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1031. MakeAsyncCQServerStreamingCall(channel);
  1032. EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
  1033. }
  1034. TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  1035. ChannelArguments args;
  1036. PhonyInterceptor::Reset();
  1037. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1038. creators;
  1039. creators.push_back(
  1040. absl::make_unique<BidiStreamingRpcHijackingInterceptorFactory>());
  1041. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1042. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1043. MakeBidiStreamingCall(channel);
  1044. }
  1045. TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  1046. ChannelArguments args;
  1047. PhonyInterceptor::Reset();
  1048. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1049. creators;
  1050. creators.push_back(absl::make_unique<LoggingInterceptorFactory>());
  1051. // Add 20 phony interceptors
  1052. for (auto i = 0; i < 20; i++) {
  1053. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  1054. }
  1055. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1056. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1057. MakeBidiStreamingCall(channel);
  1058. LoggingInterceptor::VerifyBidiStreamingCall();
  1059. // Make sure all 20 phony interceptors were run
  1060. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  1061. }
  1062. class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
  1063. protected:
  1064. ClientGlobalInterceptorEnd2endTest() {
  1065. int port = grpc_pick_unused_port_or_die();
  1066. ServerBuilder builder;
  1067. server_address_ = "localhost:" + std::to_string(port);
  1068. builder.AddListeningPort(server_address_, InsecureServerCredentials());
  1069. builder.RegisterService(&service_);
  1070. server_ = builder.BuildAndStart();
  1071. }
  1072. ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); }
  1073. std::string server_address_;
  1074. TestServiceImpl service_;
  1075. std::unique_ptr<Server> server_;
  1076. };
  1077. TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) {
  1078. // We should ideally be registering a global interceptor only once per
  1079. // process, but for the purposes of testing, it should be fine to modify the
  1080. // registered global interceptor when there are no ongoing gRPC operations
  1081. PhonyInterceptorFactory global_factory;
  1082. experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  1083. ChannelArguments args;
  1084. PhonyInterceptor::Reset();
  1085. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1086. creators;
  1087. // Add 20 phony interceptors
  1088. creators.reserve(20);
  1089. for (auto i = 0; i < 20; i++) {
  1090. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  1091. }
  1092. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1093. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1094. MakeCall(channel);
  1095. // Make sure all 20 phony interceptors were run with the global interceptor
  1096. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 21);
  1097. experimental::TestOnlyResetGlobalClientInterceptorFactory();
  1098. }
  1099. TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  1100. // We should ideally be registering a global interceptor only once per
  1101. // process, but for the purposes of testing, it should be fine to modify the
  1102. // registered global interceptor when there are no ongoing gRPC operations
  1103. LoggingInterceptorFactory global_factory;
  1104. experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  1105. ChannelArguments args;
  1106. PhonyInterceptor::Reset();
  1107. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1108. creators;
  1109. // Add 20 phony interceptors
  1110. creators.reserve(20);
  1111. for (auto i = 0; i < 20; i++) {
  1112. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  1113. }
  1114. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1115. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1116. MakeCall(channel);
  1117. LoggingInterceptor::VerifyUnaryCall();
  1118. // Make sure all 20 phony interceptors were run
  1119. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  1120. experimental::TestOnlyResetGlobalClientInterceptorFactory();
  1121. }
  1122. TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  1123. // We should ideally be registering a global interceptor only once per
  1124. // process, but for the purposes of testing, it should be fine to modify the
  1125. // registered global interceptor when there are no ongoing gRPC operations
  1126. HijackingInterceptorFactory global_factory;
  1127. experimental::RegisterGlobalClientInterceptorFactory(&global_factory);
  1128. ChannelArguments args;
  1129. PhonyInterceptor::Reset();
  1130. std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
  1131. creators;
  1132. // Add 20 phony interceptors
  1133. creators.reserve(20);
  1134. for (auto i = 0; i < 20; i++) {
  1135. creators.push_back(absl::make_unique<PhonyInterceptorFactory>());
  1136. }
  1137. auto channel = experimental::CreateCustomChannelWithInterceptors(
  1138. server_address_, InsecureChannelCredentials(), args, std::move(creators));
  1139. MakeCall(channel);
  1140. // Make sure all 20 phony interceptors were run
  1141. EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), 20);
  1142. experimental::TestOnlyResetGlobalClientInterceptorFactory();
  1143. }
  1144. } // namespace
  1145. } // namespace testing
  1146. } // namespace grpc
  1147. int main(int argc, char** argv) {
  1148. grpc::testing::TestEnvironment env(argc, argv);
  1149. ::testing::InitGoogleTest(&argc, argv);
  1150. return RUN_ALL_TESTS();
  1151. }