Explorar o código

Fix providing context to server functions (#3174)

* Fix providing context to server functions

* Fix fullstack context test
Evan Almloff hai 7 meses
pai
achega
416811ba3a

+ 40 - 0
packages/fullstack/src/serve_config.rs

@@ -7,6 +7,8 @@ use std::path::PathBuf;
 
 use dioxus_lib::prelude::dioxus_core::LaunchConfig;
 
+use crate::server::ContextProviders;
+
 /// A ServeConfig is used to configure how to serve a Dioxus application. It contains information about how to serve static assets, and what content to render with [`dioxus-ssr`].
 #[derive(Clone, Default)]
 pub struct ServeConfigBuilder {
@@ -14,6 +16,7 @@ pub struct ServeConfigBuilder {
     pub(crate) index_html: Option<String>,
     pub(crate) index_path: Option<PathBuf>,
     pub(crate) incremental: Option<dioxus_isrg::IncrementalRendererConfig>,
+    pub(crate) context_providers: ContextProviders,
 }
 
 impl LaunchConfig for ServeConfigBuilder {}
@@ -26,6 +29,7 @@ impl ServeConfigBuilder {
             index_html: None,
             index_path: None,
             incremental: None,
+            context_providers: Default::default(),
         }
     }
 
@@ -115,6 +119,40 @@ impl ServeConfigBuilder {
         self
     }
 
+    /// Provide context to the root and server functions. You can use this context
+    /// while rendering with [`consume_context`](dioxus_lib::prelude::consume_context) or in server functions with [`FromContext`](crate::prelude::FromContext).
+    ///
+    /// Context will be forwarded from the LaunchBuilder if it is provided.
+    ///
+    /// ```rust, no_run
+    /// use dioxus::prelude::*;
+    ///
+    /// dioxus::LaunchBuilder::new()
+    ///     // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
+    ///     .with_context(server_only! {
+    ///         1234567890u32
+    ///     })
+    ///     .launch(app);
+    ///
+    /// #[server]
+    /// async fn read_context() -> Result<u32, ServerFnError> {
+    ///     // You can extract values from the server context with the `extract` function
+    ///     let FromContext(value) = extract().await?;
+    ///     Ok(value)
+    /// }
+    ///
+    /// fn app() -> Element {
+    ///     let future = use_resource(read_context);
+    ///     rsx! {
+    ///         h1 { "{future:?}" }
+    ///     }
+    /// }
+    /// ```
+    pub fn context_providers(mut self, state: ContextProviders) -> Self {
+        self.context_providers = state;
+        self
+    }
+
     /// Build the ServeConfig. This may fail if the index.html file is not found.
     pub fn build(self) -> Result<ServeConfig, UnableToLoadIndex> {
         // The CLI always bundles static assets into the exe/public directory
@@ -137,6 +175,7 @@ impl ServeConfigBuilder {
         Ok(ServeConfig {
             index,
             incremental: self.incremental,
+            context_providers: self.context_providers,
         })
     }
 }
@@ -241,6 +280,7 @@ pub(crate) struct IndexHtml {
 pub struct ServeConfig {
     pub(crate) index: IndexHtml,
     pub(crate) incremental: Option<dioxus_isrg::IncrementalRendererConfig>,
+    pub(crate) context_providers: ContextProviders,
 }
 
 impl LaunchConfig for ServeConfig {}

+ 14 - 0
packages/fullstack/src/server/launch.rs

@@ -30,6 +30,20 @@ pub fn launch(
                 })
                 .unwrap_or_else(ServeConfig::new);
 
+            // Extend the config's context providers with the context providers from the launch builder
+            let platform_config = platform_config.map(|mut cfg| {
+                let mut contexts = contexts;
+                let cfg_context_providers = cfg.context_providers.clone();
+                for i in 0..cfg_context_providers.len() {
+                    contexts.push(Box::new({
+                        let cfg_context_providers = cfg_context_providers.clone();
+                        move || (cfg_context_providers[i])()
+                    }));
+                }
+                cfg.context_providers = std::sync::Arc::new(contexts);
+                cfg
+            });
+
             // Get the address the server should run on. If the CLI is running, the CLI proxies fullstack into the main address
             // and we use the generated address the CLI gives us
             let address = dioxus_cli_config::fullstack_address_or_localhost();

+ 39 - 20
packages/fullstack/src/server/mod.rs

@@ -185,19 +185,7 @@ where
         for (path, method) in server_fn::axum::server_fn_paths() {
             tracing::trace!("Registering server function: {} {}", method, path);
             let context_providers = context_providers.clone();
-            let handler = move |req| {
-                handle_server_fns_inner(
-                    path,
-                    move |server_context| {
-                        for index in 0..context_providers.len() {
-                            let context_providers = context_providers.clone();
-                            server_context
-                                .insert_boxed_factory(Box::new(move || context_providers[index]()));
-                        }
-                    },
-                    req,
-                )
-            };
+            let handler = move |req| handle_server_fns_inner(path, context_providers, req);
             self = match method {
                 Method::GET => self.route(path, get(handler)),
                 Method::POST => self.route(path, post(handler)),
@@ -258,10 +246,18 @@ where
         Cfg: TryInto<ServeConfig, Error = Error>,
         Error: std::error::Error,
     {
+        let cfg = cfg.try_into();
+        let context_providers = cfg
+            .as_ref()
+            .map(|cfg| cfg.context_providers.clone())
+            .unwrap_or_default();
+
         // Add server functions and render index.html
-        let server = self.serve_static_assets().register_server_functions();
+        let server = self
+            .serve_static_assets()
+            .register_server_functions_with_context(context_providers);
 
-        match cfg.try_into() {
+        match cfg {
             Ok(cfg) => {
                 let ssr_state = SSRState::new(&cfg);
                 server.fallback(
@@ -287,6 +283,13 @@ fn apply_request_parts_to_response<B>(
     }
 }
 
+fn add_server_context(server_context: &DioxusServerContext, context_providers: &ContextProviders) {
+    for index in 0..context_providers.len() {
+        let context_providers = context_providers.clone();
+        server_context.insert_boxed_factory(Box::new(move || context_providers[index]()));
+    }
+}
+
 /// State used by [`render_handler`] to render a dioxus component with axum
 #[derive(Clone)]
 pub struct RenderHandleState {
@@ -384,7 +387,17 @@ pub async fn render_handler(
 
     let cfg = &state.config;
     let ssr_state = state.ssr_state();
-    let build_virtual_dom = state.build_virtual_dom.clone();
+    let build_virtual_dom = {
+        let build_virtual_dom = state.build_virtual_dom.clone();
+        let context_providers = state.config.context_providers.clone();
+        move || {
+            let mut vdom = build_virtual_dom();
+            for state in context_providers.as_slice() {
+                vdom.insert_any_root_context(state());
+            }
+            vdom
+        }
+    };
 
     let (parts, _) = request.into_parts();
     let url = parts
@@ -394,10 +407,13 @@ pub async fn render_handler(
         .to_string();
     let parts: Arc<parking_lot::RwLock<http::request::Parts>> =
         Arc::new(parking_lot::RwLock::new(parts));
+    // Create the server context with info from the request
     let server_context = DioxusServerContext::from_shared_parts(parts.clone());
+    // Provide additional context from the render state
+    add_server_context(&server_context, &state.config.context_providers);
 
     match ssr_state
-        .render(url, cfg, move || build_virtual_dom(), &server_context)
+        .render(url, cfg, build_virtual_dom, &server_context)
         .await
     {
         Ok((freshness, rx)) => {
@@ -424,7 +440,7 @@ fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
 /// A handler for Dioxus server functions. This will run the server function and return the result.
 async fn handle_server_fns_inner(
     path: &str,
-    additional_context: impl Fn(&DioxusServerContext) + 'static + Clone + Send,
+    additional_context: ContextProviders,
     req: Request<Body>,
 ) -> impl IntoResponse {
     use server_fn::middleware::Service;
@@ -438,8 +454,10 @@ async fn handle_server_fns_inner(
         if let Some(mut service) =
             server_fn::axum::get_server_fn_service(&path_string)
         {
+            // Create the server context with info from the request
             let server_context = DioxusServerContext::new(parts);
-            additional_context(&server_context);
+            // Provide additional context from the render state
+            add_server_context(&server_context, &additional_context);
 
             // store Accepts and Referrer in case we need them for redirect (below)
             let accepts_html = req
@@ -451,7 +469,8 @@ async fn handle_server_fns_inner(
             let referrer = req.headers().get(REFERER).cloned();
 
             // actually run the server fn (which may use the server context)
-            let mut res = ProvideServerContext::new(service.run(req), server_context.clone()).await;
+            let fut = with_server_context(server_context.clone(), || service.run(req));
+            let mut res = ProvideServerContext::new(fut, server_context.clone()).await;
 
             // it it accepts text/html (i.e., is a plain form post) and doesn't already have a
             // Location set, then redirect to Referer

+ 8 - 0
packages/fullstack/src/server_context.rs

@@ -262,6 +262,14 @@ mod server_fn_impl {
     }
 }
 
+#[test]
+fn server_context_as_any_map() {
+    let parts = http::Request::new(()).into_parts().0;
+    let server_context = DioxusServerContext::new(parts);
+    server_context.insert_boxed_factory(Box::new(|| Box::new(1234u32)));
+    assert_eq!(server_context.get::<u32>().unwrap(), 1234u32);
+}
+
 std::thread_local! {
     pub(crate) static SERVER_CONTEXT: std::cell::RefCell<Box<DioxusServerContext>> = Default::default();
 }

+ 12 - 1
packages/playwright-tests/fullstack/src/main.rs

@@ -8,7 +8,9 @@
 use dioxus::{prelude::*, CapturedError};
 
 fn main() {
-    dioxus::launch(app);
+    dioxus::LaunchBuilder::new()
+        .with_context(1234u32)
+        .launch(app);
 }
 
 fn app() -> Element {
@@ -38,8 +40,15 @@ fn app() -> Element {
     }
 }
 
+#[cfg(feature = "server")]
+async fn assert_server_context_provided() {
+    let FromContext(i): FromContext<u32> = extract().await.unwrap();
+    assert_eq!(i, 1234u32);
+}
+
 #[server(PostServerData)]
 async fn post_server_data(data: String) -> Result<(), ServerFnError> {
+    assert_server_context_provided().await;
     println!("Server received: {}", data);
 
     Ok(())
@@ -47,11 +56,13 @@ async fn post_server_data(data: String) -> Result<(), ServerFnError> {
 
 #[server(GetServerData)]
 async fn get_server_data() -> Result<String, ServerFnError> {
+    assert_server_context_provided().await;
     Ok("Hello from the server!".to_string())
 }
 
 #[server]
 async fn server_error() -> Result<String, ServerFnError> {
+    assert_server_context_provided().await;
     tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
     Err(ServerFnError::new("the server threw an error!"))
 }