浏览代码

Fix dx serve proxying of websocket connections established from the browser (#3895)

* better ws request detection

* proxy ws connections

* improve naming

* perform case-insensitive header comparison

* preserve request headers when proxying websocket connection

* avoid side effects in conversion function

* fix websocket proxying for axum 0.8

* remove redundant conversions

---------

Co-authored-by: Evan Almloff <evanalmloff@gmail.com>
Arvid Fahlström Myrman 1 月之前
父节点
当前提交
8ca8b042af
共有 4 个文件被更改,包括 147 次插入19 次删除
  1. 1 0
      packages/cli/Cargo.toml
  2. 1 0
      packages/cli/src/serve/mod.rs
  3. 7 19
      packages/cli/src/serve/proxy.rs
  4. 138 0
      packages/cli/src/serve/proxy_ws.rs

+ 1 - 0
packages/cli/Cargo.toml

@@ -41,6 +41,7 @@ html_parser = { workspace = true }
 cargo_metadata = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
 tokio-stream = { workspace = true }
+tokio-tungstenite = { workspace = true }
 chrono = { workspace = true }
 anyhow = { workspace = true }
 hyper = { workspace = true }

+ 1 - 0
packages/cli/src/serve/mod.rs

@@ -3,6 +3,7 @@ use crate::{AppBuilder, BuildId, BuildMode, BuilderUpdate, Result, ServeArgs, Tr
 mod ansi_buffer;
 mod output;
 mod proxy;
+mod proxy_ws;
 mod runner;
 mod server;
 mod update;

+ 7 - 19
packages/cli/src/serve/proxy.rs

@@ -4,12 +4,14 @@ use crate::{Error, Result};
 
 use anyhow::{anyhow, Context};
 use axum::body::Body;
+use axum::http::request::Parts;
 use axum::{body::Body as MyBody, response::IntoResponse};
 use axum::{
     http::StatusCode,
     routing::{any, MethodRouter},
     Router,
 };
+use hyper::header::*;
 use hyper::{Request, Response, Uri};
 use hyper_util::{
     client::legacy::{self, connect::HttpConnector},
@@ -99,7 +101,7 @@ pub(crate) fn proxy_to(
 ) -> MethodRouter {
     let client = ProxyClient::new(url.clone());
 
-    any(move |mut req: Request<MyBody>| async move {
+    any(move |parts: Parts, mut req: Request<MyBody>| async move {
         // Prevent request loops
         if req.headers().get("x-proxied-by-dioxus").is_some() {
             return Err(Response::builder()
@@ -115,26 +117,12 @@ pub(crate) fn proxy_to(
             "true".parse().expect("header value is valid"),
         );
 
-        // We have to throw a redirect for ws connections since the upgrade handler will not be called
-        // Our _dioxus handler will override this in the default case
+        let upgrade = req.headers().get(UPGRADE);
         if req.uri().scheme().map(|f| f.as_str()) == Some("ws")
             || req.uri().scheme().map(|f| f.as_str()) == Some("wss")
+            || upgrade.is_some_and(|h| h.as_bytes().eq_ignore_ascii_case(b"websocket"))
         {
-            let new_host = url.host().unwrap_or("localhost");
-            let proxied_uri = format!(
-                "{scheme}://{host}:{port}{path_and_query}",
-                scheme = req.uri().scheme_str().unwrap_or("ws"),
-                port = url.port().unwrap(),
-                host = new_host,
-                path_and_query = req
-                    .uri()
-                    .path_and_query()
-                    .map(|f| f.to_string())
-                    .unwrap_or_default()
-            );
-            tracing::info!(dx_src = ?TraceSrc::Dev, "Proxied websocket request {req:?} to {proxied_uri}");
-
-            return Ok(axum::response::Redirect::permanent(&proxied_uri).into_response());
+            return super::proxy_ws::proxy_websocket(parts, req, &url).await;
         }
 
         if nocache {
@@ -169,7 +157,7 @@ pub(crate) fn proxy_to(
     })
 }
 
-fn handle_proxy_error(e: Error) -> axum::http::Response<axum::body::Body> {
+pub(crate) fn handle_proxy_error(e: Error) -> axum::http::Response<axum::body::Body> {
     tracing::error!(dx_src = ?TraceSrc::Dev, "Proxy error: {}", e);
     axum::http::Response::builder()
         .status(axum::http::StatusCode::INTERNAL_SERVER_ERROR)

+ 138 - 0
packages/cli/src/serve/proxy_ws.rs

@@ -0,0 +1,138 @@
+use crate::logging::TraceSrc;
+use crate::serve::proxy::handle_proxy_error;
+use anyhow::Context;
+use axum::body::Body;
+use axum::extract::ws::{CloseFrame as ClientCloseFrame, Message as ClientMessage};
+use axum::extract::{FromRequestParts, WebSocketUpgrade};
+use axum::http::request::Parts;
+use axum::response::IntoResponse;
+use futures_util::{SinkExt, StreamExt};
+use hyper::{Request, Response, Uri};
+use tokio_tungstenite::tungstenite::protocol::{
+    CloseFrame as ServerCloseFrame, Message as ServerMessage,
+};
+
+pub(crate) async fn proxy_websocket(
+    mut parts: Parts,
+    req: Request<Body>,
+    backend_url: &Uri,
+) -> Result<Response<Body>, Response<Body>> {
+    let ws = WebSocketUpgrade::from_request_parts(&mut parts, &())
+        .await
+        .map_err(IntoResponse::into_response)?;
+
+    tracing::info!(dx_src = ?TraceSrc::Dev, "Proxying websocket connection {req:?}");
+    let proxied_request = into_proxied_request(req, backend_url).map_err(handle_proxy_error)?;
+    tracing::info!(dx_src = ?TraceSrc::Dev, "Connection proxied to {proxied_uri}", proxied_uri = proxied_request.uri());
+
+    Ok(ws.on_upgrade(move |client_ws| async move {
+        match handle_ws_connection(client_ws, proxied_request).await {
+            Ok(()) => tracing::info!(dx_src = ?TraceSrc::Dev, "Websocket connection closed"),
+            Err(e) => {
+                tracing::error!(dx_src = ?TraceSrc::Dev, "Error proxying websocket connection: {e}")
+            }
+        }
+    }))
+}
+
+fn into_proxied_request(
+    req: Request<Body>,
+    backend_url: &Uri,
+) -> crate::Result<tokio_tungstenite::tungstenite::handshake::client::Request> {
+    // ensure headers from original request are preserved
+    let (mut request_parts, _) = req.into_parts();
+    let mut uri_parts = request_parts.uri.into_parts();
+    uri_parts.scheme = uri_parts.scheme.or("ws".parse().ok());
+    uri_parts.authority = backend_url.authority().cloned();
+    request_parts.uri = Uri::from_parts(uri_parts).context("Could not construct proxy URI")?;
+    Ok(Request::from_parts(request_parts, ()))
+}
+
+#[derive(thiserror::Error, Debug)]
+enum WsError {
+    #[error("Error connecting to server: {0}")]
+    Connect(tokio_tungstenite::tungstenite::Error),
+    #[error("Error sending message to server: {0}")]
+    ToServer(tokio_tungstenite::tungstenite::Error),
+    #[error("Error receiving message from server: {0}")]
+    FromServer(tokio_tungstenite::tungstenite::Error),
+    #[error("Error sending message to client: {0}")]
+    ToClient(axum::Error),
+    #[error("Error receiving message from client: {0}")]
+    FromClient(axum::Error),
+}
+
+async fn handle_ws_connection(
+    mut client_ws: axum::extract::ws::WebSocket,
+    proxied_request: tokio_tungstenite::tungstenite::handshake::client::Request,
+) -> Result<(), WsError> {
+    let (mut server_ws, _) = tokio_tungstenite::connect_async(proxied_request)
+        .await
+        .map_err(WsError::Connect)?;
+
+    let mut closed = false;
+    while !closed {
+        tokio::select! {
+            Some(server_msg) = server_ws.next() => {
+                closed = matches!(server_msg, Ok(ServerMessage::Close(..)));
+                match server_msg.map_err(WsError::FromServer)?.into_msg() {
+                    Ok(msg) => client_ws.send(msg).await.map_err(WsError::ToClient)?,
+                    Err(UnexpectedRawFrame) => tracing::warn!(dx_src = ?TraceSrc::Dev, "Dropping unexpected raw websocket frame"),
+                }
+            },
+            Some(client_msg) = client_ws.next() => {
+                closed = matches!(client_msg, Ok(ClientMessage::Close(..)));
+                let Ok(msg) = client_msg.map_err(WsError::FromClient)?.into_msg();
+                server_ws.send(msg).await.map_err(WsError::ToServer)?;
+            },
+            else => break,
+        }
+    }
+
+    Ok(())
+}
+
+trait IntoMsg<T> {
+    type Error;
+    fn into_msg(self) -> Result<T, Self::Error>;
+}
+
+impl IntoMsg<ServerMessage> for ClientMessage {
+    type Error = std::convert::Infallible;
+    fn into_msg(self) -> Result<ServerMessage, Self::Error> {
+        use ServerMessage as SM;
+        Ok(match self {
+            Self::Text(v) => SM::Text(v.as_str().into()),
+            Self::Binary(v) => SM::Binary(v),
+            Self::Ping(v) => SM::Ping(v),
+            Self::Pong(v) => SM::Pong(v),
+            Self::Close(v) => SM::Close(v.map(|cf| ServerCloseFrame {
+                code: cf.code.into(),
+                reason: cf.reason.as_str().into(),
+            })),
+        })
+    }
+}
+
+struct UnexpectedRawFrame;
+impl IntoMsg<ClientMessage> for ServerMessage {
+    type Error = UnexpectedRawFrame;
+    fn into_msg(self) -> Result<ClientMessage, Self::Error> {
+        use ClientMessage as CM;
+        Ok(match self {
+            Self::Text(v) => CM::Text(v.as_str().into()),
+            Self::Binary(v) => CM::Binary(v),
+            Self::Ping(v) => CM::Ping(v),
+            Self::Pong(v) => CM::Pong(v),
+            Self::Close(v) => CM::Close(v.map(|cf| ClientCloseFrame {
+                code: cf.code.into(),
+                reason: cf.reason.as_str().into(),
+            })),
+            Self::Frame(_) => {
+                // this variant should never be returned by next(), but handle it
+                // gracefully by dropping it instead of panicking out of an abundance of caution
+                return Err(UnexpectedRawFrame);
+            }
+        })
+    }
+}