server_context.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. use parking_lot::RwLock;
  2. use std::any::Any;
  3. use std::collections::HashMap;
  4. use std::sync::Arc;
  5. type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
  6. /// A shared context for server functions that contains information about the request and middleware state.
  7. ///
  8. /// You should not construct this directly inside components or server functions. Instead use [`server_context()`] to get the server context from the current request.
  9. ///
  10. /// # Example
  11. ///
  12. /// ```rust, no_run
  13. /// # use dioxus::prelude::*;
  14. /// #[server]
  15. /// async fn read_headers() -> Result<(), ServerFnError> {
  16. /// let server_context = server_context();
  17. /// let headers: http::HeaderMap = server_context.extract().await?;
  18. /// println!("{:?}", headers);
  19. /// Ok(())
  20. /// }
  21. /// ```
  22. #[derive(Clone)]
  23. pub struct DioxusServerContext {
  24. shared_context: std::sync::Arc<RwLock<SendSyncAnyMap>>,
  25. response_parts: std::sync::Arc<RwLock<http::response::Parts>>,
  26. pub(crate) parts: Arc<RwLock<http::request::Parts>>,
  27. }
  28. enum ContextType {
  29. Factory(Box<dyn Fn() -> Box<dyn Any> + Send + Sync>),
  30. Value(Box<dyn Any + Send + Sync>),
  31. }
  32. impl ContextType {
  33. fn downcast<T: Clone + 'static>(&self) -> Option<T> {
  34. match self {
  35. ContextType::Value(value) => value.downcast_ref::<T>().cloned(),
  36. ContextType::Factory(factory) => factory().downcast::<T>().ok().map(|v| *v),
  37. }
  38. }
  39. }
  40. #[allow(clippy::derivable_impls)]
  41. impl Default for DioxusServerContext {
  42. fn default() -> Self {
  43. Self {
  44. shared_context: std::sync::Arc::new(RwLock::new(HashMap::new())),
  45. response_parts: std::sync::Arc::new(RwLock::new(
  46. http::response::Response::new(()).into_parts().0,
  47. )),
  48. parts: std::sync::Arc::new(RwLock::new(http::request::Request::new(()).into_parts().0)),
  49. }
  50. }
  51. }
  52. mod server_fn_impl {
  53. use super::*;
  54. use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
  55. use std::any::{Any, TypeId};
  56. impl DioxusServerContext {
  57. /// Create a new server context from a request
  58. pub fn new(parts: http::request::Parts) -> Self {
  59. Self {
  60. parts: Arc::new(RwLock::new(parts)),
  61. shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
  62. response_parts: std::sync::Arc::new(RwLock::new(
  63. http::response::Response::new(()).into_parts().0,
  64. )),
  65. }
  66. }
  67. /// Create a server context from a shared parts
  68. #[allow(unused)]
  69. pub(crate) fn from_shared_parts(parts: Arc<RwLock<http::request::Parts>>) -> Self {
  70. Self {
  71. parts,
  72. shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
  73. response_parts: std::sync::Arc::new(RwLock::new(
  74. http::response::Response::new(()).into_parts().0,
  75. )),
  76. }
  77. }
  78. /// Clone a value from the shared server context. If you are using [`DioxusRouterExt`](crate::prelude::DioxusRouterExt), any values you insert into
  79. /// the launch context will also be available in the server context.
  80. ///
  81. /// Example:
  82. /// ```rust, no_run
  83. /// use dioxus::prelude::*;
  84. ///
  85. /// LaunchBuilder::new()
  86. /// // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
  87. /// .with_context(server_only! {
  88. /// 1234567890u32
  89. /// })
  90. /// .launch(app);
  91. ///
  92. /// #[server]
  93. /// async fn read_context() -> Result<u32, ServerFnError> {
  94. /// // You can extract values from the server context with the `extract` function
  95. /// let FromContext(value) = extract().await?;
  96. /// Ok(value)
  97. /// }
  98. ///
  99. /// fn app() -> Element {
  100. /// let future = use_resource(read_context);
  101. /// rsx! {
  102. /// h1 { "{future:?}" }
  103. /// }
  104. /// }
  105. /// ```
  106. pub fn get<T: Any + Send + Sync + Clone + 'static>(&self) -> Option<T> {
  107. self.shared_context
  108. .read()
  109. .get(&TypeId::of::<T>())
  110. .map(|v| v.downcast::<T>().unwrap())
  111. }
  112. /// Insert a value into the shared server context
  113. pub fn insert<T: Any + Send + Sync + 'static>(&self, value: T) {
  114. self.insert_any(Box::new(value));
  115. }
  116. /// Insert a boxed `Any` value into the shared server context
  117. pub fn insert_any(&self, value: Box<dyn Any + Send + Sync + 'static>) {
  118. self.shared_context
  119. .write()
  120. .insert((*value).type_id(), ContextType::Value(value));
  121. }
  122. /// Insert a factory that creates a non-sync value for the shared server context
  123. pub fn insert_factory<F, T>(&self, value: F)
  124. where
  125. F: Fn() -> T + Send + Sync + 'static,
  126. T: 'static,
  127. {
  128. self.shared_context.write().insert(
  129. TypeId::of::<T>(),
  130. ContextType::Factory(Box::new(move || Box::new(value()))),
  131. );
  132. }
  133. /// Insert a boxed factory that creates a non-sync value for the shared server context
  134. pub fn insert_boxed_factory(&self, value: Box<dyn Fn() -> Box<dyn Any> + Send + Sync>) {
  135. self.shared_context
  136. .write()
  137. .insert((*value()).type_id(), ContextType::Factory(value));
  138. }
  139. /// Get the response parts from the server context
  140. ///
  141. #[doc = include_str!("../docs/request_origin.md")]
  142. ///
  143. /// # Example
  144. ///
  145. /// ```rust, no_run
  146. /// # use dioxus::prelude::*;
  147. /// #[server]
  148. /// async fn set_headers() -> Result<(), ServerFnError> {
  149. /// let server_context = server_context();
  150. /// let cookies = server_context.response_parts()
  151. /// .headers()
  152. /// .get("Cookie")
  153. /// .ok_or_else(|| ServerFnError::msg("failed to find Cookie header in the response"))?;
  154. /// println!("{:?}", cookies);
  155. /// Ok(())
  156. /// }
  157. /// ```
  158. pub fn response_parts(&self) -> RwLockReadGuard<'_, http::response::Parts> {
  159. self.response_parts.read()
  160. }
  161. /// Get the response parts from the server context
  162. ///
  163. #[doc = include_str!("../docs/request_origin.md")]
  164. ///
  165. /// # Example
  166. ///
  167. /// ```rust, no_run
  168. /// # use dioxus::prelude::*;
  169. /// #[server]
  170. /// async fn set_headers() -> Result<(), ServerFnError> {
  171. /// let server_context = server_context();
  172. /// server_context.response_parts_mut()
  173. /// .headers_mut()
  174. /// .insert("Cookie", "dioxus=fullstack");
  175. /// Ok(())
  176. /// }
  177. /// ```
  178. pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
  179. self.response_parts.write()
  180. }
  181. /// Get the request parts
  182. ///
  183. #[doc = include_str!("../docs/request_origin.md")]
  184. ///
  185. /// # Example
  186. ///
  187. /// ```rust, no_run
  188. /// # use dioxus::prelude::*;
  189. /// #[server]
  190. /// async fn read_headers() -> Result<(), ServerFnError> {
  191. /// let server_context = server_context();
  192. /// let id: &i32 = server_context.request_parts()
  193. /// .extensions
  194. /// .get()
  195. /// .ok_or_else(|| ServerFnError::msg("failed to find i32 extension in the request"))?;
  196. /// println!("{:?}", id);
  197. /// Ok(())
  198. /// }
  199. /// ```
  200. pub fn request_parts(&self) -> parking_lot::RwLockReadGuard<'_, http::request::Parts> {
  201. self.parts.read()
  202. }
  203. /// Get the request parts mutably
  204. ///
  205. #[doc = include_str!("../docs/request_origin.md")]
  206. ///
  207. /// # Example
  208. ///
  209. /// ```rust, no_run
  210. /// # use dioxus::prelude::*;
  211. /// #[server]
  212. /// async fn read_headers() -> Result<(), ServerFnError> {
  213. /// let server_context = server_context();
  214. /// let id: i32 = server_context.request_parts_mut()
  215. /// .extensions
  216. /// .remove()
  217. /// .ok_or_else(|| ServerFnError::msg("failed to find i32 extension in the request"))?;
  218. /// println!("{:?}", id);
  219. /// Ok(())
  220. /// }
  221. /// ```
  222. pub fn request_parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
  223. self.parts.write()
  224. }
  225. /// Extract part of the request.
  226. ///
  227. #[doc = include_str!("../docs/request_origin.md")]
  228. ///
  229. /// # Example
  230. ///
  231. /// ```rust, no_run
  232. /// # use dioxus::prelude::*;
  233. /// #[server]
  234. /// async fn read_headers() -> Result<(), ServerFnError> {
  235. /// let server_context = server_context();
  236. /// let headers: http::HeaderMap = server_context.extract().await?;
  237. /// println!("{:?}", headers);
  238. /// Ok(())
  239. /// }
  240. /// ```
  241. pub async fn extract<M, T: FromServerContext<M>>(&self) -> Result<T, T::Rejection> {
  242. T::from_request(self).await
  243. }
  244. }
  245. }
  246. std::thread_local! {
  247. pub(crate) static SERVER_CONTEXT: std::cell::RefCell<Box<DioxusServerContext>> = Default::default();
  248. }
  249. /// Get information about the current server request.
  250. ///
  251. /// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
  252. pub fn server_context() -> DioxusServerContext {
  253. SERVER_CONTEXT.with(|ctx| *ctx.borrow().clone())
  254. }
  255. /// Extract some part from the current server request.
  256. ///
  257. /// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
  258. pub async fn extract<E: FromServerContext<I>, I>() -> Result<E, E::Rejection> {
  259. E::from_request(&server_context()).await
  260. }
  261. /// Run a function inside of the server context.
  262. pub fn with_server_context<O>(context: DioxusServerContext, f: impl FnOnce() -> O) -> O {
  263. // before polling the future, we need to set the context
  264. let prev_context = SERVER_CONTEXT.with(|ctx| ctx.replace(Box::new(context)));
  265. // poll the future, which may call server_context()
  266. let result = f();
  267. // after polling the future, we need to restore the context
  268. SERVER_CONTEXT.with(|ctx| ctx.replace(prev_context));
  269. result
  270. }
  271. /// A future that provides the server context to the inner future
  272. #[pin_project::pin_project]
  273. pub struct ProvideServerContext<F: std::future::Future> {
  274. context: DioxusServerContext,
  275. #[pin]
  276. f: F,
  277. }
  278. impl<F: std::future::Future> ProvideServerContext<F> {
  279. /// Create a new future that provides the server context to the inner future
  280. pub fn new(f: F, context: DioxusServerContext) -> Self {
  281. Self { f, context }
  282. }
  283. }
  284. impl<F: std::future::Future> std::future::Future for ProvideServerContext<F> {
  285. type Output = F::Output;
  286. fn poll(
  287. self: std::pin::Pin<&mut Self>,
  288. cx: &mut std::task::Context<'_>,
  289. ) -> std::task::Poll<Self::Output> {
  290. let this = self.project();
  291. let context = this.context.clone();
  292. with_server_context(context, || this.f.poll(cx))
  293. }
  294. }
  295. /// A trait for extracting types from the server context
  296. #[async_trait::async_trait]
  297. pub trait FromServerContext<I = ()>: Sized {
  298. /// The error type returned when extraction fails. This type must implement `std::error::Error`.
  299. type Rejection;
  300. /// Extract this type from the server context.
  301. async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection>;
  302. }
  303. /// A type was not found in the server context
  304. pub struct NotFoundInServerContext<T: 'static>(std::marker::PhantomData<T>);
  305. impl<T: 'static> std::fmt::Debug for NotFoundInServerContext<T> {
  306. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  307. let type_name = std::any::type_name::<T>();
  308. write!(f, "`{type_name}` not found in server context")
  309. }
  310. }
  311. impl<T: 'static> std::fmt::Display for NotFoundInServerContext<T> {
  312. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  313. let type_name = std::any::type_name::<T>();
  314. write!(f, "`{type_name}` not found in server context")
  315. }
  316. }
  317. impl<T: 'static> std::error::Error for NotFoundInServerContext<T> {}
  318. /// Extract a value from the server context provided through the launch builder context or [`DioxusServerContext::insert`]
  319. ///
  320. /// Example:
  321. /// ```rust, no_run
  322. /// use dioxus::prelude::*;
  323. ///
  324. /// dioxus::LaunchBuilder::new()
  325. /// // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
  326. /// .with_context(server_only! {
  327. /// 1234567890u32
  328. /// })
  329. /// .launch(app);
  330. ///
  331. /// #[server]
  332. /// async fn read_context() -> Result<u32, ServerFnError> {
  333. /// // You can extract values from the server context with the `extract` function
  334. /// let FromContext(value) = extract().await?;
  335. /// Ok(value)
  336. /// }
  337. ///
  338. /// fn app() -> Element {
  339. /// let future = use_resource(read_context);
  340. /// rsx! {
  341. /// h1 { "{future:?}" }
  342. /// }
  343. /// }
  344. /// ```
  345. pub struct FromContext<T: std::marker::Send + std::marker::Sync + Clone + 'static>(pub T);
  346. #[async_trait::async_trait]
  347. impl<T: Send + Sync + Clone + 'static> FromServerContext for FromContext<T> {
  348. type Rejection = NotFoundInServerContext<T>;
  349. async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
  350. Ok(Self(req.get::<T>().ok_or({
  351. NotFoundInServerContext::<T>(std::marker::PhantomData::<T>)
  352. })?))
  353. }
  354. }
  355. #[cfg(feature = "axum")]
  356. #[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
  357. /// An adapter for axum extractors for the server context
  358. pub struct Axum;
  359. #[cfg(feature = "axum")]
  360. #[async_trait::async_trait]
  361. impl<
  362. I: axum::extract::FromRequestParts<(), Rejection = R>,
  363. R: axum::response::IntoResponse + std::error::Error,
  364. > FromServerContext<Axum> for I
  365. {
  366. type Rejection = R;
  367. #[allow(clippy::all)]
  368. async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
  369. let mut lock = req.request_parts_mut();
  370. I::from_request_parts(&mut lock, &()).await
  371. }
  372. }