Преглед изворни кода

Fix salvo fullstack builds

Evan Almloff пре 1 година
родитељ
комит
0e83d48c04

+ 1 - 1
packages/fullstack/Cargo.toml

@@ -28,7 +28,7 @@ tower = { version = "0.4.13", features = ["util"], optional = true }
 axum-macros = "0.3.7"
 
 # salvo
-salvo = { version = "0.37.7", optional = true, features = ["serve-static", "ws", "compression"] }
+salvo = { version = "0.46.0", optional = true, features = ["serve-static", "websocket", "compression"] }
 serde = "1.0.159"
 
 # Dioxus + SSR

+ 0 - 2
packages/fullstack/examples/salvo-hello-world/src/main.rs

@@ -50,8 +50,6 @@ fn app(cx: Scope<AppProps>) -> Element {
 async fn post_server_data(data: String) -> Result<(), ServerFnError> {
     // The server context contains information about the current request and allows you to modify the response.
     let cx = server_context();
-    cx.response_headers_mut()
-        .insert("Set-Cookie", "foo=bar".parse().unwrap());
     println!("Server received: {}", data);
     println!("Request parts are {:?}", cx.request_parts());
 

+ 98 - 109
packages/fullstack/src/adapters/salvo_adapter.rs

@@ -49,21 +49,31 @@
 //! }
 //! ```
 
-use hyper::{http::HeaderValue, StatusCode};
+use http_body_util::{BodyExt, Limited};
+use hyper::body::Body as HyperBody;
+use hyper::StatusCode;
 use salvo::{
     async_trait, handler,
+    http::{
+        cookie::{Cookie, CookieJar},
+        ParseError,
+    },
     serve_static::{StaticDir, StaticFile},
-    Depot, FlowCtrl, Handler, Request, Response, Router,
+    Depot, Error as SalvoError, FlowCtrl, Handler, Request, Response, Router,
 };
-use server_fn::{Encoding, Payload, ServerFunctionRegistry};
+use server_fn::{Encoding, ServerFunctionRegistry};
 use std::error::Error;
 use std::sync::Arc;
-use tokio::task::spawn_blocking;
+use std::sync::RwLock;
 
 use crate::{
-    prelude::*, render::SSRState, serve_config::ServeConfig, server_fn::DioxusServerFnRegistry,
+    layer::Service, prelude::*, render::SSRState, serve_config::ServeConfig,
+    server_fn::DioxusServerFnRegistry, server_fn_service,
 };
 
+type HyperRequest = hyper::Request<hyper::Body>;
+type HyperResponse = hyper::Response<HyperBody>;
+
 /// A extension trait with utilities for integrating Dioxus with your Salvo router.
 pub trait DioxusRouterExt {
     /// Registers server functions with a custom handler function. This allows you to pass custom context to your server functions by generating a [`DioxusServerContext`] from the request.
@@ -297,13 +307,71 @@ impl DioxusRouterExt for Router {
 }
 
 /// Extracts the parts of a request that are needed for server functions. This will take parts of the request and replace them with empty values.
-pub fn extract_parts(req: &mut Request) -> RequestParts {
-    RequestParts {
-        method: std::mem::take(req.method_mut()),
-        uri: std::mem::take(req.uri_mut()),
-        version: req.version(),
-        headers: std::mem::take(req.headers_mut()),
-        extensions: std::mem::take(req.extensions_mut()),
+pub fn extract_parts(req: &mut Request) -> http::request::Parts {
+    let mut parts = http::request::Request::new(()).into_parts().0;
+    parts.method = std::mem::take(req.method_mut());
+    parts.uri = std::mem::take(req.uri_mut());
+    parts.version = req.version();
+    parts.headers = std::mem::take(req.headers_mut());
+    parts.extensions = std::mem::take(req.extensions_mut());
+
+    parts
+}
+
+fn apply_request_parts_to_response(
+    headers: hyper::header::HeaderMap,
+    response: &mut salvo::prelude::Response,
+) {
+    let mut_headers = response.headers_mut();
+    for (key, value) in headers.iter() {
+        mut_headers.insert(key, value.clone());
+    }
+}
+
+#[inline]
+async fn convert_request(req: &mut Request) -> Result<HyperRequest, SalvoError> {
+    let forward_url: hyper::Uri = TryFrom::try_from(req.uri()).map_err(SalvoError::other)?;
+    let mut build = hyper::Request::builder()
+        .method(req.method())
+        .uri(&forward_url);
+    for (key, value) in req.headers() {
+        build = build.header(key, value);
+    }
+    static SECURE_MAX_SIZE: usize = 64 * 1024;
+
+    let body = Limited::new(req.take_body(), SECURE_MAX_SIZE)
+        .collect()
+        .await
+        .map_err(ParseError::other)?
+        .to_bytes();
+    build.body(body.into()).map_err(SalvoError::other)
+}
+
+#[inline]
+async fn convert_response(response: HyperResponse, res: &mut Response) {
+    let (parts, body) = response.into_parts();
+    let http::response::Parts {
+        version,
+        headers,
+        status,
+        ..
+    } = parts;
+    res.status_code = Some(status);
+    res.version = version;
+    res.cookies = CookieJar::new();
+    for cookie in headers.get_all(http::header::SET_COOKIE).iter() {
+        if let Some(cookie) = cookie
+            .to_str()
+            .ok()
+            .and_then(|s| Cookie::parse(s.to_string()).ok())
+        {
+            res.cookies.add_original(cookie);
+        }
+    }
+    res.headers = headers;
+    res.version = version;
+    if let Ok(bytes) = hyper::body::to_bytes(body).await {
+        res.body = bytes.into()
     }
 }
 
@@ -328,8 +396,9 @@ impl<P: Clone + serde::Serialize + Send + Sync + 'static> Handler for SSRHandler
             depot.inject(renderer.clone());
             renderer
         };
-        let parts: Arc<RequestParts> = Arc::new(extract_parts(req));
-        let route = parts.uri.path().to_string();
+
+        let route = req.uri().path().to_string();
+        let parts: Arc<RwLock<http::request::Parts>> = Arc::new(RwLock::new(extract_parts(req)));
         let server_context = DioxusServerContext::new(parts);
 
         match renderer_pool
@@ -341,7 +410,8 @@ impl<P: Clone + serde::Serialize + Send + Sync + 'static> Handler for SSRHandler
 
                 res.write_body(html).unwrap();
 
-                *res.headers_mut() = server_context.take_response_headers();
+                let headers = server_context.response_parts().unwrap().headers.clone();
+                apply_request_parts_to_response(headers, res);
                 freshness.write(res.headers_mut());
             }
             Err(err) => {
@@ -375,95 +445,14 @@ impl ServerFnHandler {
 #[handler]
 impl ServerFnHandler {
     async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response) {
-        let Self {
-            server_context,
-            function,
-        } = self;
-
-        let query = req
-            .uri()
-            .query()
-            .unwrap_or_default()
-            .as_bytes()
-            .to_vec()
-            .into();
-        let body = hyper::body::to_bytes(req.body_mut().unwrap()).await;
-        let Ok(body)=body else {
-            handle_error(body.err().unwrap(), res);
-            return;
-        };
-        let headers = req.headers();
-        let accept_header = headers.get("Accept").cloned();
-
-        let parts = Arc::new(extract_parts(req));
-
-        // Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime
-        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
-        spawn_blocking({
-            let function = function.clone();
-            let mut server_context = server_context.clone();
-            server_context.parts = parts;
-            move || {
-                tokio::runtime::Runtime::new()
-                    .expect("couldn't spawn runtime")
-                    .block_on(async move {
-                        let data = match function.encoding() {
-                            Encoding::Url | Encoding::Cbor => &body,
-                            Encoding::GetJSON | Encoding::GetCBOR => &query,
-                        };
-                        let server_function_future = function.call((), data);
-                        let server_function_future = ProvideServerContext::new(
-                            server_function_future,
-                            server_context.clone(),
-                        );
-                        let resp = server_function_future.await;
-
-                        resp_tx.send(resp).unwrap();
-                    })
-            }
-        });
-        let result = resp_rx.await.unwrap();
-
-        // Set the headers from the server context
-        *res.headers_mut() = server_context.take_response_headers();
-
-        match result {
-            Ok(serialized) => {
-                // if this is Accept: application/json then send a serialized JSON response
-                let accept_header = accept_header.as_ref().and_then(|value| value.to_str().ok());
-                if accept_header == Some("application/json")
-                    || accept_header
-                        == Some(
-                            "application/\
-                                x-www-form-urlencoded",
-                        )
-                    || accept_header == Some("application/cbor")
-                {
-                    res.set_status_code(StatusCode::OK);
-                }
-
-                match serialized {
-                    Payload::Binary(data) => {
-                        res.headers_mut()
-                            .insert("Content-Type", HeaderValue::from_static("application/cbor"));
-                        res.write_body(data).unwrap();
-                    }
-                    Payload::Url(data) => {
-                        res.headers_mut().insert(
-                            "Content-Type",
-                            HeaderValue::from_static(
-                                "application/\
-                                    x-www-form-urlencoded",
-                            ),
-                        );
-                        res.write_body(data).unwrap();
-                    }
-                    Payload::Json(data) => {
-                        res.headers_mut()
-                            .insert("Content-Type", HeaderValue::from_static("application/json"));
-                        res.write_body(data).unwrap();
-                    }
-                }
+        match convert_request(req).await {
+            Ok(hyper_req) => {
+                let response =
+                    server_fn_service(self.server_context.clone(), self.function.clone())
+                        .run(hyper_req)
+                        .await
+                        .unwrap();
+                convert_response(response, res).await;
             }
             Err(err) => handle_error(err, res),
         }
@@ -472,7 +461,7 @@ impl ServerFnHandler {
 
 fn handle_error(error: impl Error + Send + Sync, res: &mut Response) {
     let mut resp_err = Response::new();
-    resp_err.set_status_code(StatusCode::INTERNAL_SERVER_ERROR);
+    resp_err.status_code(StatusCode::INTERNAL_SERVER_ERROR);
     resp_err.render(format!("Internal Server Error: {}", error));
     *res = resp_err;
 }
@@ -509,8 +498,8 @@ impl HotReloadHandler {
         _depot: &mut Depot,
         res: &mut Response,
     ) -> Result<(), salvo::http::StatusError> {
-        use salvo::ws::Message;
-        use salvo::ws::WebSocketUpgrade;
+        use salvo::websocket::Message;
+        use salvo::websocket::WebSocketUpgrade;
 
         let state = crate::hot_reload::spawn_hot_reload().await;
 
@@ -557,10 +546,10 @@ impl HotReloadHandler {
 #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
 #[handler]
 async fn ignore_ws(req: &mut Request, res: &mut Response) -> Result<(), salvo::http::StatusError> {
-    use salvo::ws::WebSocketUpgrade;
+    use salvo::websocket::WebSocketUpgrade;
     WebSocketUpgrade::new()
         .upgrade(req, res, |mut ws| async move {
-            let _ = ws.send(salvo::ws::Message::text("connected")).await;
+            let _ = ws.send(salvo::websocket::Message::text("connected")).await;
             while let Some(msg) = ws.recv().await {
                 if msg.is_err() {
                     return;

+ 3 - 2
packages/fullstack/src/launch.rs

@@ -161,11 +161,12 @@ pub async fn launch_server<P: Clone + serde::Serialize + Send + Sync + 'static>(
     #[cfg(all(feature = "salvo", not(feature = "axum"), not(feature = "warp")))]
     {
         use crate::adapters::salvo_adapter::DioxusRouterExt;
+        use salvo::conn::Listener;
         let router = salvo::Router::new().serve_dioxus_application("", cfg).hoop(
             salvo::compression::Compression::new()
-                .with_algos(&[salvo::prelude::CompressionAlgo::Gzip]),
+                .enable_gzip(salvo::prelude::CompressionLevel::Default),
         );
-        salvo::Server::new(salvo::listener::TcpListener::bind(addr))
+        salvo::Server::new(salvo::conn::tcp::TcpListener::new(addr).bind().await)
             .serve(router)
             .await;
     }