Procházet zdrojové kódy

initial axum implementation

Evan Almloff před 2 roky
rodič
revize
939e75541e

+ 14 - 1
packages/server/Cargo.toml

@@ -18,4 +18,17 @@ axum = { version = "0.6.1", optional = true, features = ["ws"] }
 salvo = { version = "0.37.7", optional = true, features = ["ws"] }
 serde = "1.0.159"
 
-dioxus = { path = "../dioxus", version = "^0.3.0" }
+dioxus = { path = "../dioxus", version = "^0.3.0" }
+
+log = "0.4.17"
+once_cell = "1.17.1"
+thiserror = "1.0.40"
+hyper = "0.14.25"
+tokio = { version = "1.27.0", features = ["full"] }
+
+[features]
+default = ["axum", "ssr"]
+warp = ["dep:warp"]
+axum = ["dep:axum"]
+salvo = ["dep:salvo"]
+ssr = ["server_fn/ssr"]

+ 100 - 18
packages/server/src/adapters/axum_adapter.rs

@@ -1,23 +1,105 @@
-use crate::{LiveViewError, LiveViewSocket};
-use axum::extract::ws::{Message, WebSocket};
-use futures_util::{SinkExt, StreamExt};
-
-/// Convert a warp websocket into a LiveViewSocket
-///
-/// This is required to launch a LiveView app using the warp web framework
-pub fn axum_socket(ws: WebSocket) -> impl LiveViewSocket {
-    ws.map(transform_rx)
-        .with(transform_tx)
-        .sink_map_err(|_| LiveViewError::SendingFailed)
+use std::{error::Error, sync::Arc};
+
+use axum::{
+    body::{self, Body, BoxBody, Full},
+    http::{HeaderMap, Request, Response, StatusCode},
+    response::IntoResponse,
+    routing::post,
+    Router,
+};
+use server_fn::{Payload, ServerFunctionRegistry};
+use tokio::task::spawn_blocking;
+
+use crate::{DioxusServerContext, DioxusServerFnRegistry, ServerFnTraitObj};
+
+trait DioxusRouterExt {
+    fn register_server_fns(self) -> Self;
+}
+
+impl DioxusRouterExt for Router {
+    fn register_server_fns(self) -> Self {
+        let mut router = self;
+        for server_fn_path in DioxusServerFnRegistry::paths_registered() {
+            let func = DioxusServerFnRegistry::get(server_fn_path).unwrap();
+            router = router.route(
+                server_fn_path,
+                post(move |headers: HeaderMap, body: Request<Body>| async move {
+                    server_fn_handler(DioxusServerContext {}, func.clone(), headers, body).await
+                    // todo!()
+                }),
+            );
+        }
+        router
+    }
 }
 
-fn transform_rx(message: Result<Message, axum::Error>) -> Result<String, LiveViewError> {
-    message
-        .map_err(|_| LiveViewError::SendingFailed)?
-        .into_text()
-        .map_err(|_| LiveViewError::SendingFailed)
+async fn server_fn_handler(
+    server_context: DioxusServerContext,
+    function: Arc<ServerFnTraitObj>,
+    headers: HeaderMap,
+    req: Request<Body>,
+) -> impl IntoResponse {
+    let (_, body) = req.into_parts();
+    let body = hyper::body::to_bytes(body).await;
+    let Ok(body)=body else {
+        return report_err(body.err().unwrap());
+    };
+
+    // Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime
+    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
+    spawn_blocking({
+        move || {
+            tokio::runtime::Runtime::new()
+                .expect("couldn't spawn runtime")
+                .block_on(async {
+                    let resp = match function(server_context, &body).await {
+                        Ok(serialized) => {
+                            // if this is Accept: application/json then send a serialized JSON response
+                            let accept_header =
+                                headers.get("Accept").and_then(|value| value.to_str().ok());
+                            let mut res = Response::builder();
+                            if accept_header == Some("application/json")
+                                || accept_header
+                                    == Some(
+                                        "application/\
+                                                 x-www-form-urlencoded",
+                                    )
+                                || accept_header == Some("application/cbor")
+                            {
+                                res = res.status(StatusCode::OK);
+                            }
+
+                            let resp = match serialized {
+                                Payload::Binary(data) => res
+                                    .header("Content-Type", "application/cbor")
+                                    .body(body::boxed(Full::from(data))),
+                                Payload::Url(data) => res
+                                    .header(
+                                        "Content-Type",
+                                        "application/\
+                                        x-www-form-urlencoded",
+                                    )
+                                    .body(body::boxed(data)),
+                                Payload::Json(data) => res
+                                    .header("Content-Type", "application/json")
+                                    .body(body::boxed(data)),
+                            };
+
+                            resp.unwrap()
+                        }
+                        Err(e) => report_err(e),
+                    };
+
+                    resp_tx.send(resp).unwrap();
+                })
+        }
+    });
+    resp_rx.await.unwrap()
 }
 
-async fn transform_tx(message: String) -> Result<Message, axum::Error> {
-    Ok(Message::Text(message))
+fn report_err<E: Error>(e: E) -> Response<BoxBody> {
+    Response::builder()
+        .status(StatusCode::INTERNAL_SERVER_ERROR)
+        .body(body::boxed(format!("Error: {}", e)))
+        .unwrap()
 }

+ 2 - 0
packages/server/src/adapters/mod.rs

@@ -0,0 +1,2 @@
+#[cfg(feature = "axum")]
+mod axum_adapter;

+ 86 - 67
packages/server/src/lib.rs

@@ -1,85 +1,104 @@
-use dioxus::prelude::*;
-use serde::{de::DeserializeOwned, Deserializer, Serialize, Serializer};
+mod adapters;
 
-// We use deref specialization to make it possible to pass either a value that implements
-pub trait SerializeToRemoteWrapper {
-    fn serialize_to_remote<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>;
-}
+// #[server(ReadPosts, "api")]
+// async fn testing(rx: i32) -> Result<u32, ServerFnError> {
+//     Ok(0)
+// }
 
-impl<T: Serialize> SerializeToRemoteWrapper for &T {
-    fn serialize_to_remote<S: Serializer>(
-        &self,
-        serializer: S,
-    ) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error> {
-        self.serialize(serializer)
-    }
-}
+pub struct DioxusServerContext {}
 
-impl<S: SerializeToRemote> SerializeToRemoteWrapper for &mut &S {
-    fn serialize_to_remote<S2: Serializer>(
-        &self,
-        serializer: S2,
-    ) -> Result<<S2 as Serializer>::Ok, <S2 as Serializer>::Error> {
-        (**self).serialize_to_remote(serializer)
-    }
-}
+#[cfg(any(feature = "ssr", doc))]
+type ServerFnTraitObj = server_fn::ServerFnTraitObj<DioxusServerContext>;
 
-pub trait SerializeToRemote {
-    fn serialize_to_remote<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error>;
-}
+#[cfg(any(feature = "ssr", doc))]
+static REGISTERED_SERVER_FUNCTIONS: once_cell::sync::Lazy<
+    std::sync::Arc<
+        std::sync::RwLock<
+            std::collections::HashMap<&'static str, std::sync::Arc<ServerFnTraitObj>>,
+        >,
+    >,
+> = once_cell::sync::Lazy::new(Default::default);
 
-impl<S: Serialize> SerializeToRemote for UseState<S> {
-    fn serialize_to_remote<S2: Serializer>(
-        &self,
-        serializer: S2,
-    ) -> Result<<S2 as Serializer>::Ok, <S2 as Serializer>::Error> {
-        self.current().serialize(serializer)
-    }
-}
+#[cfg(any(feature = "ssr", doc))]
+/// The registry of all Dioxus server functions.
+pub struct DioxusServerFnRegistry;
 
-// We use deref specialization to make it possible to pass either a value that implements
-pub trait DeserializeOnRemoteWrapper {
-    type Output;
+#[cfg(any(feature = "ssr"))]
+impl server_fn::ServerFunctionRegistry<DioxusServerContext> for DioxusServerFnRegistry {
+    type Error = ServerRegistrationFnError;
 
-    fn deserialize_on_remote<'a, D: Deserializer<'a>>(
-        deserializer: D,
-    ) -> Result<Self::Output, D::Error>;
-}
-
-impl<T: DeserializeOwned> DeserializeOnRemoteWrapper for &T {
-    type Output = T;
+    fn register(
+        url: &'static str,
+        server_function: std::sync::Arc<ServerFnTraitObj>,
+    ) -> Result<(), Self::Error> {
+        // store it in the hashmap
+        let mut write = REGISTERED_SERVER_FUNCTIONS
+            .write()
+            .map_err(|e| ServerRegistrationFnError::Poisoned(e.to_string()))?;
+        let prev = write.insert(url, server_function);
 
-    fn deserialize_on_remote<'a, D: Deserializer<'a>>(
-        deserializer: D,
-    ) -> Result<Self::Output, D::Error> {
-        T::deserialize(deserializer)
+        // if there was already a server function with this key,
+        // return Err
+        match prev {
+            Some(_) => Err(ServerRegistrationFnError::AlreadyRegistered(format!(
+                "There was already a server function registered at {:?}. \
+                     This can happen if you use the same server function name \
+                     in two different modules
+                on `stable` or in `release` mode.",
+                url
+            ))),
+            None => Ok(()),
+        }
     }
-}
 
-impl<D: DeserializeOnRemote> DeserializeOnRemoteWrapper for &mut &D {
-    type Output = D::Output;
+    /// Returns the server function registered at the given URL, or `None` if no function is registered at that URL.
+    fn get(url: &str) -> Option<std::sync::Arc<ServerFnTraitObj>> {
+        REGISTERED_SERVER_FUNCTIONS
+            .read()
+            .ok()
+            .and_then(|fns| fns.get(url).cloned())
+    }
 
-    fn deserialize_on_remote<'a, D2: Deserializer<'a>>(
-        deserializer: D2,
-    ) -> Result<Self::Output, D2::Error> {
-        D::deserialize_on_remote(deserializer)
+    /// Returns a list of all registered server functions.
+    fn paths_registered() -> Vec<&'static str> {
+        REGISTERED_SERVER_FUNCTIONS
+            .read()
+            .ok()
+            .map(|fns| fns.keys().cloned().collect())
+            .unwrap_or_default()
     }
 }
 
-pub trait DeserializeOnRemote {
-    type Output;
-
-    fn deserialize_on_remote<'a, D: Deserializer<'a>>(
-        deserializer: D,
-    ) -> Result<Self::Output, D::Error>;
+#[cfg(any(feature = "ssr", doc))]
+/// Errors that can occur when registering a server function.
+#[derive(thiserror::Error, Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub enum ServerRegistrationFnError {
+    /// The server function is already registered.
+    #[error("The server function {0} is already registered")]
+    AlreadyRegistered(String),
+    /// The server function registry is poisoned.
+    #[error("The server function registry is poisoned: {0}")]
+    Poisoned(String),
 }
 
-impl<D: DeserializeOwned> DeserializeOnRemote for UseState<D> {
-    type Output = D;
-
-    fn deserialize_on_remote<'a, D2: Deserializer<'a>>(
-        deserializer: D2,
-    ) -> Result<Self::Output, D2::Error> {
-        D::deserialize(deserializer)
+/// Defines a "server function." A server function can be called from the server or the client,
+/// but the body of its code will only be run on the server, i.e., if a crate feature `ssr` is enabled.
+///
+/// (This follows the same convention as the Dioxus framework's distinction between `ssr` for server-side rendering,
+/// and `csr` and `hydrate` for client-side rendering and hydration, respectively.)
+///
+/// Server functions are created using the `server` macro.
+///
+/// The function should be registered by calling `ServerFn::register()`. The set of server functions
+/// can be queried on the server for routing purposes by calling [server_fn_by_path].
+///
+/// Technically, the trait is implemented on a type that describes the server function's arguments.
+pub trait ServerFn: server_fn::ServerFn<DioxusServerContext> {
+    /// Registers the server function, allowing the server to query it by URL.
+    #[cfg(any(feature = "ssr", doc))]
+    fn register() -> Result<(), server_fn::ServerFnError> {
+        Self::register_in::<DioxusServerFnRegistry>()
     }
 }
+
+impl<T> ServerFn for T where T: server_fn::ServerFn<DioxusServerContext> {}