async_customized_auth_server.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright 2020 The gRPC Authors
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Server of the Python AsyncIO example of customizing authentication mechanism."""
  15. import argparse
  16. import asyncio
  17. import logging
  18. from typing import Awaitable, Callable, Tuple
  19. import _credentials
  20. import grpc
  21. helloworld_pb2, helloworld_pb2_grpc = grpc.protos_and_services(
  22. "helloworld.proto")
  23. _LOGGER = logging.getLogger(__name__)
  24. _LOGGER.setLevel(logging.INFO)
  25. _LISTEN_ADDRESS_TEMPLATE = 'localhost:%d'
  26. _SIGNATURE_HEADER_KEY = 'x-signature'
  27. class SignatureValidationInterceptor(grpc.aio.ServerInterceptor):
  28. def __init__(self):
  29. def abort(ignored_request, context: grpc.aio.ServicerContext) -> None:
  30. context.abort(grpc.StatusCode.UNAUTHENTICATED, 'Invalid signature')
  31. self._abort_handler = grpc.unary_unary_rpc_method_handler(abort)
  32. async def intercept_service(
  33. self, continuation: Callable[[grpc.HandlerCallDetails],
  34. Awaitable[grpc.RpcMethodHandler]],
  35. handler_call_details: grpc.HandlerCallDetails
  36. ) -> grpc.RpcMethodHandler:
  37. # Example HandlerCallDetails object:
  38. # _HandlerCallDetails(
  39. # method=u'/helloworld.Greeter/SayHello',
  40. # invocation_metadata=...)
  41. method_name = handler_call_details.method.split('/')[-1]
  42. expected_metadata = (_SIGNATURE_HEADER_KEY, method_name[::-1])
  43. if expected_metadata in handler_call_details.invocation_metadata:
  44. return await continuation(handler_call_details)
  45. else:
  46. return self._abort_handler
  47. class SimpleGreeter(helloworld_pb2_grpc.GreeterServicer):
  48. async def SayHello(self, request: helloworld_pb2.HelloRequest,
  49. unused_context) -> helloworld_pb2.HelloReply:
  50. return helloworld_pb2.HelloReply(message='Hello, %s!' % request.name)
  51. async def run_server(port: int) -> Tuple[grpc.aio.Server, int]:
  52. # Bind interceptor to server
  53. server = grpc.aio.server(interceptors=(SignatureValidationInterceptor(),))
  54. helloworld_pb2_grpc.add_GreeterServicer_to_server(SimpleGreeter(), server)
  55. # Loading credentials
  56. server_credentials = grpc.ssl_server_credentials(((
  57. _credentials.SERVER_CERTIFICATE_KEY,
  58. _credentials.SERVER_CERTIFICATE,
  59. ),))
  60. # Pass down credentials
  61. port = server.add_secure_port(_LISTEN_ADDRESS_TEMPLATE % port,
  62. server_credentials)
  63. await server.start()
  64. return server, port
  65. async def main() -> None:
  66. parser = argparse.ArgumentParser()
  67. parser.add_argument('--port',
  68. nargs='?',
  69. type=int,
  70. default=50051,
  71. help='the listening port')
  72. args = parser.parse_args()
  73. server, port = await run_server(args.port)
  74. logging.info('Server is listening at port :%d', port)
  75. await server.wait_for_termination()
  76. if __name__ == '__main__':
  77. logging.basicConfig(level=logging.INFO)
  78. asyncio.run(main())