// proxy.rs
use crate::{Result, WebProxyConfig};
use anyhow::Context;
use axum::{http::StatusCode, routing::any, Router};
use hyper::{Request, Response, Uri};
  5. #[derive(Debug, Clone)]
  6. struct ProxyClient {
  7. inner: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
  8. url: Uri,
  9. }
  10. impl ProxyClient {
  11. fn new(url: Uri) -> Self {
  12. let https = hyper_rustls::HttpsConnectorBuilder::new()
  13. .with_native_roots()
  14. .https_or_http()
  15. .enable_http1()
  16. .build();
  17. Self {
  18. inner: hyper::Client::builder().build(https),
  19. url,
  20. }
  21. }
  22. async fn send(
  23. &self,
  24. mut req: Request<hyper::body::Body>,
  25. ) -> Result<Response<hyper::body::Body>> {
  26. let mut uri_parts = req.uri().clone().into_parts();
  27. uri_parts.authority = self.url.authority().cloned();
  28. uri_parts.scheme = self.url.scheme().cloned();
  29. *req.uri_mut() = Uri::from_parts(uri_parts).context("Invalid URI parts")?;
  30. self.inner
  31. .request(req)
  32. .await
  33. .map_err(crate::error::Error::ProxyRequestError)
  34. }
  35. }
  36. /// Add routes to the router handling the specified proxy config.
  37. ///
  38. /// We will proxy requests directed at either:
  39. ///
  40. /// - the exact path of the proxy config's backend URL, e.g. /api
  41. /// - the exact path with a trailing slash, e.g. /api/
  42. /// - any subpath of the backend URL, e.g. /api/foo/bar
  43. pub fn add_proxy(mut router: Router, proxy: &WebProxyConfig) -> Result<Router> {
  44. let url: Uri = proxy.backend.parse()?;
  45. let path = url.path().to_string();
  46. let client = ProxyClient::new(url);
  47. // We also match everything after the path using a wildcard matcher.
  48. let wildcard_client = client.clone();
  49. router = router.route(
  50. // Always remove trailing /'s so that the exact route
  51. // matches.
  52. path.trim_end_matches('/'),
  53. any(move |req| async move {
  54. client
  55. .send(req)
  56. .await
  57. .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
  58. }),
  59. );
  60. // Wildcard match anything else _after_ the backend URL's path.
  61. // Note that we know `path` ends with a trailing `/` in this branch,
  62. // so `wildcard` will look like `http://localhost/api/*proxywildcard`.
  63. let wildcard = format!("{}/*proxywildcard", path.trim_end_matches('/'));
  64. router = router.route(
  65. &wildcard,
  66. any(move |req| async move {
  67. wildcard_client
  68. .send(req)
  69. .await
  70. .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))
  71. }),
  72. );
  73. Ok(router)
  74. }
  75. #[cfg(test)]
  76. mod test {
  77. use super::*;
  78. use axum::{extract::Path, Router};
  79. fn setup_servers(
  80. mut config: WebProxyConfig,
  81. ) -> (
  82. tokio::task::JoinHandle<()>,
  83. tokio::task::JoinHandle<()>,
  84. String,
  85. ) {
  86. let backend_router = Router::new().route(
  87. "/*path",
  88. any(|path: Path<String>| async move { format!("backend: {}", path.0) }),
  89. );
  90. let backend_server = axum::Server::bind(&"127.0.0.1:0".parse().unwrap())
  91. .serve(backend_router.into_make_service());
  92. let backend_addr = backend_server.local_addr();
  93. let backend_handle = tokio::spawn(async move { backend_server.await.unwrap() });
  94. config.backend = format!("http://{}{}", backend_addr, config.backend);
  95. let router = super::add_proxy(Router::new(), &config);
  96. let server = axum::Server::bind(&"127.0.0.1:0".parse().unwrap())
  97. .serve(router.unwrap().into_make_service());
  98. let server_addr = server.local_addr();
  99. let server_handle = tokio::spawn(async move { server.await.unwrap() });
  100. (backend_handle, server_handle, server_addr.to_string())
  101. }
  102. async fn test_proxy_requests(path: String) {
  103. let config = WebProxyConfig {
  104. // Normally this would be an absolute URL including scheme/host/port,
  105. // but in these tests we need to let the OS choose the port so tests
  106. // don't conflict, so we'll concatenate the final address and this
  107. // path together.
  108. // So in day to day usage, use `http://localhost:8000/api` instead!
  109. backend: path,
  110. };
  111. let (backend_handle, server_handle, server_addr) = setup_servers(config);
  112. let resp = hyper::Client::new()
  113. .get(format!("http://{}/api", server_addr).parse().unwrap())
  114. .await
  115. .unwrap();
  116. assert_eq!(resp.status(), StatusCode::OK);
  117. assert_eq!(
  118. hyper::body::to_bytes(resp.into_body()).await.unwrap(),
  119. "backend: /api"
  120. );
  121. let resp = hyper::Client::new()
  122. .get(format!("http://{}/api/", server_addr).parse().unwrap())
  123. .await
  124. .unwrap();
  125. assert_eq!(resp.status(), StatusCode::OK);
  126. assert_eq!(
  127. hyper::body::to_bytes(resp.into_body()).await.unwrap(),
  128. "backend: /api/"
  129. );
  130. let resp = hyper::Client::new()
  131. .get(
  132. format!("http://{}/api/subpath", server_addr)
  133. .parse()
  134. .unwrap(),
  135. )
  136. .await
  137. .unwrap();
  138. assert_eq!(resp.status(), StatusCode::OK);
  139. assert_eq!(
  140. hyper::body::to_bytes(resp.into_body()).await.unwrap(),
  141. "backend: /api/subpath"
  142. );
  143. backend_handle.abort();
  144. server_handle.abort();
  145. }
  146. #[tokio::test]
  147. async fn add_proxy() {
  148. test_proxy_requests("/api".to_string()).await;
  149. }
  150. #[tokio::test]
  151. async fn add_proxy_trailing_slash() {
  152. test_proxy_requests("/api/".to_string()).await;
  153. }
  154. }