use parking_lot::RwLock; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; type SendSyncAnyMap = std::collections::HashMap; /// A shared context for server functions that contains information about the request and middleware state. /// /// You should not construct this directly inside components or server functions. Instead use [`server_context()`] to get the server context from the current request. /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn read_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// let headers: http::HeaderMap = server_context.extract().await?; /// println!("{:?}", headers); /// Ok(()) /// } /// ``` #[derive(Clone)] pub struct DioxusServerContext { shared_context: std::sync::Arc>, response_parts: std::sync::Arc>, pub(crate) parts: Arc>, } enum ContextType { Factory(Box Box + Send + Sync>), Value(Box), } impl ContextType { fn downcast(&self) -> Option { match self { ContextType::Value(value) => value.downcast_ref::().cloned(), ContextType::Factory(factory) => factory().downcast::().ok().map(|v| *v), } } } #[allow(clippy::derivable_impls)] impl Default for DioxusServerContext { fn default() -> Self { Self { shared_context: std::sync::Arc::new(RwLock::new(HashMap::new())), response_parts: std::sync::Arc::new(RwLock::new( http::response::Response::new(()).into_parts().0, )), parts: std::sync::Arc::new(RwLock::new(http::request::Request::new(()).into_parts().0)), } } } mod server_fn_impl { use super::*; use parking_lot::{RwLockReadGuard, RwLockWriteGuard}; use std::any::{Any, TypeId}; impl DioxusServerContext { /// Create a new server context from a request pub fn new(parts: http::request::Parts) -> Self { Self { parts: Arc::new(RwLock::new(parts)), shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())), response_parts: std::sync::Arc::new(RwLock::new( http::response::Response::new(()).into_parts().0, )), } } /// Create a server context from a shared parts #[allow(unused)] pub(crate) fn from_shared_parts(parts: Arc>) -> Self { Self { parts, shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())), response_parts: std::sync::Arc::new(RwLock::new( http::response::Response::new(()).into_parts().0, )), } } /// Clone a value from the shared server context. If you are using [`DioxusRouterExt`](crate::prelude::DioxusRouterExt), any values you insert into /// the launch context will also be available in the server context. /// /// Example: /// ```rust, no_run /// use dioxus::prelude::*; /// /// LaunchBuilder::new() /// // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder /// .with_context(server_only! { /// 1234567890u32 /// }) /// .launch(app); /// /// #[server] /// async fn read_context() -> Result { /// // You can extract values from the server context with the `extract` function /// let FromContext(value) = extract().await?; /// Ok(value) /// } /// /// fn app() -> Element { /// let future = use_resource(read_context); /// rsx! { /// h1 { "{future:?}" } /// } /// } /// ``` pub fn get(&self) -> Option { self.shared_context .read() .get(&TypeId::of::()) .map(|v| v.downcast::().unwrap()) } /// Insert a value into the shared server context pub fn insert(&self, value: T) { self.insert_any(Box::new(value)); } /// Insert a boxed `Any` value into the shared server context pub fn insert_any(&self, value: Box) { self.shared_context .write() .insert((*value).type_id(), ContextType::Value(value)); } /// Insert a factory that creates a non-sync value for the shared server context pub fn insert_factory(&self, value: F) where F: Fn() -> T + Send + Sync + 'static, T: 'static, { self.shared_context.write().insert( TypeId::of::(), ContextType::Factory(Box::new(move || Box::new(value()))), ); } /// Insert a boxed factory that creates a non-sync value for the shared server context pub fn insert_boxed_factory(&self, value: Box Box + Send + Sync>) { self.shared_context .write() .insert((*value()).type_id(), ContextType::Factory(value)); } /// Get the response parts from the server context /// #[doc = include_str!("../docs/request_origin.md")] /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn set_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// let cookies = server_context.response_parts() /// .headers() /// .get("Cookie") /// .ok_or_else(|| ServerFnError::msg("failed to find Cookie header in the response"))?; /// println!("{:?}", cookies); /// Ok(()) /// } /// ``` pub fn response_parts(&self) -> RwLockReadGuard<'_, http::response::Parts> { self.response_parts.read() } /// Get the response parts from the server context /// #[doc = include_str!("../docs/request_origin.md")] /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn set_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// server_context.response_parts_mut() /// .headers_mut() /// .insert("Cookie", "dioxus=fullstack"); /// Ok(()) /// } /// ``` pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> { self.response_parts.write() } /// Get the request parts /// #[doc = include_str!("../docs/request_origin.md")] /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn read_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// let id: &i32 = server_context.request_parts() /// .extensions /// .get() /// .ok_or_else(|| ServerFnError::msg("failed to find i32 extension in the request"))?; /// println!("{:?}", id); /// Ok(()) /// } /// ``` pub fn request_parts(&self) -> parking_lot::RwLockReadGuard<'_, http::request::Parts> { self.parts.read() } /// Get the request parts mutably /// #[doc = include_str!("../docs/request_origin.md")] /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn read_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// let id: i32 = server_context.request_parts_mut() /// .extensions /// .remove() /// .ok_or_else(|| ServerFnError::msg("failed to find i32 extension in the request"))?; /// println!("{:?}", id); /// Ok(()) /// } /// ``` pub fn request_parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> { self.parts.write() } /// Extract part of the request. /// #[doc = include_str!("../docs/request_origin.md")] /// /// # Example /// /// ```rust, no_run /// # use dioxus::prelude::*; /// #[server] /// async fn read_headers() -> Result<(), ServerFnError> { /// let server_context = server_context(); /// let headers: http::HeaderMap = server_context.extract().await?; /// println!("{:?}", headers); /// Ok(()) /// } /// ``` pub async fn extract>(&self) -> Result { T::from_request(self).await } } } std::thread_local! { pub(crate) static SERVER_CONTEXT: std::cell::RefCell> = Default::default(); } /// Get information about the current server request. /// /// This function will only provide the current server context if it is called from a server function or on the server rendering a request. pub fn server_context() -> DioxusServerContext { SERVER_CONTEXT.with(|ctx| *ctx.borrow().clone()) } /// Extract some part from the current server request. /// /// This function will only provide the current server context if it is called from a server function or on the server rendering a request. pub async fn extract, I>() -> Result { E::from_request(&server_context()).await } /// Run a function inside of the server context. pub fn with_server_context(context: DioxusServerContext, f: impl FnOnce() -> O) -> O { // before polling the future, we need to set the context let prev_context = SERVER_CONTEXT.with(|ctx| ctx.replace(Box::new(context))); // poll the future, which may call server_context() let result = f(); // after polling the future, we need to restore the context SERVER_CONTEXT.with(|ctx| ctx.replace(prev_context)); result } /// A future that provides the server context to the inner future #[pin_project::pin_project] pub struct ProvideServerContext { context: DioxusServerContext, #[pin] f: F, } impl ProvideServerContext { /// Create a new future that provides the server context to the inner future pub fn new(f: F, context: DioxusServerContext) -> Self { Self { f, context } } } impl std::future::Future for ProvideServerContext { type Output = F::Output; fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { let this = self.project(); let context = this.context.clone(); with_server_context(context, || this.f.poll(cx)) } } /// A trait for extracting types from the server context #[async_trait::async_trait] pub trait FromServerContext: Sized { /// The error type returned when extraction fails. This type must implement `std::error::Error`. type Rejection; /// Extract this type from the server context. async fn from_request(req: &DioxusServerContext) -> Result; } /// A type was not found in the server context pub struct NotFoundInServerContext(std::marker::PhantomData); impl std::fmt::Debug for NotFoundInServerContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let type_name = std::any::type_name::(); write!(f, "`{type_name}` not found in server context") } } impl std::fmt::Display for NotFoundInServerContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let type_name = std::any::type_name::(); write!(f, "`{type_name}` not found in server context") } } impl std::error::Error for NotFoundInServerContext {} /// Extract a value from the server context provided through the launch builder context or [`DioxusServerContext::insert`] /// /// Example: /// ```rust, no_run /// use dioxus::prelude::*; /// /// dioxus::LaunchBuilder::new() /// // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder /// .with_context(server_only! { /// 1234567890u32 /// }) /// .launch(app); /// /// #[server] /// async fn read_context() -> Result { /// // You can extract values from the server context with the `extract` function /// let FromContext(value) = extract().await?; /// Ok(value) /// } /// /// fn app() -> Element { /// let future = use_resource(read_context); /// rsx! { /// h1 { "{future:?}" } /// } /// } /// ``` pub struct FromContext(pub T); #[async_trait::async_trait] impl FromServerContext for FromContext { type Rejection = NotFoundInServerContext; async fn from_request(req: &DioxusServerContext) -> Result { Ok(Self(req.get::().ok_or({ NotFoundInServerContext::(std::marker::PhantomData::) })?)) } } #[cfg(feature = "axum")] #[cfg_attr(docsrs, doc(cfg(feature = "axum")))] /// An adapter for axum extractors for the server context pub struct Axum; #[cfg(feature = "axum")] #[async_trait::async_trait] impl< I: axum::extract::FromRequestParts<(), Rejection = R>, R: axum::response::IntoResponse + std::error::Error, > FromServerContext for I { type Rejection = R; #[allow(clippy::all)] async fn from_request(req: &DioxusServerContext) -> Result { let mut lock = req.request_parts_mut(); I::from_request_parts(&mut lock, &()).await } }