axum_core.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. //! Dioxus core utilities for the [Axum](https://docs.rs/axum/latest/axum/index.html) server framework.
  2. //!
  3. //! # Example
  4. //! ```rust, no_run
  5. //! #![allow(non_snake_case)]
  //! use axum::routing::get;
  //! use dioxus::prelude::*;
  7. //!
  8. //! fn main() {
  9. //! #[cfg(feature = "web")]
  10. //! // Hydrate the application on the client
  11. //! dioxus::launch(app);
  12. //! #[cfg(feature = "server")]
  13. //! {
  14. //! tokio::runtime::Runtime::new()
  15. //! .unwrap()
  16. //! .block_on(async move {
  17. //! // Get the address the server should run on. If the CLI is running, the CLI proxies fullstack into the main address
  18. //! // and we use the generated address the CLI gives us
  19. //! let address = dioxus::cli_config::fullstack_address_or_localhost();
  20. //! let listener = tokio::net::TcpListener::bind(address)
  21. //! .await
  22. //! .unwrap();
  23. //! axum::serve(
  24. //! listener,
  25. //! axum::Router::new()
  26. //! // Server side render the application, serve static assets, and register server functions
  27. //! .register_server_functions()
  28. //! .fallback(get(render_handler)
  29. //! // Note: ServeConfig::new won't work on WASM
  //! .with_state(RenderHandleState::new(ServeConfig::new().unwrap(), app))
  31. //! )
  32. //! .into_make_service(),
  33. //! )
  34. //! .await
  35. //! .unwrap();
  36. //! });
  37. //! }
  38. //! }
  39. //!
  40. //! fn app() -> Element {
  41. //! let mut text = use_signal(|| "...".to_string());
  42. //!
  43. //! rsx! {
  44. //! button {
  45. //! onclick: move |_| async move {
  46. //! if let Ok(data) = get_server_data().await {
  47. //! text.set(data);
  48. //! }
  49. //! },
  50. //! "Run a server function"
  51. //! }
  52. //! "Server said: {text}"
  53. //! }
  54. //! }
  55. //!
  56. //! #[server(GetServerData)]
  57. //! async fn get_server_data() -> Result<String, ServerFnError> {
  58. //! Ok("Hello from the server!".to_string())
  59. //! }
  //! ```
  //!
  //! # WASM support
  //!
  //! These utilities compile to the WASM family of targets, while the more complete ones found in [server] don't.
  65. use std::sync::Arc;
  66. use crate::prelude::*;
  67. use crate::render::SSRError;
  68. use crate::ContextProviders;
  69. use axum::body;
  70. use axum::extract::State;
  71. use axum::routing::*;
  72. use axum::{
  73. body::Body,
  74. http::{Request, Response, StatusCode},
  75. response::IntoResponse,
  76. };
  77. use dioxus_lib::prelude::{Element, VirtualDom};
  78. use http::header::*;
  79. /// A extension trait with server function utilities for integrating Dioxus with your Axum router.
  80. pub trait DioxusRouterFnExt<S> {
  81. /// Registers server functions with the default handler. This handler function will pass an empty [`DioxusServerContext`] to your server functions.
  82. ///
  83. /// # Example
  84. /// ```rust, no_run
  85. /// # use dioxus_lib::prelude::*;
  86. /// # use dioxus_fullstack::prelude::*;
  87. /// #[tokio::main]
  88. /// async fn main() {
  89. /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
  90. /// let router = axum::Router::new()
  91. /// // Register server functions routes with the default handler
  92. /// .register_server_functions()
  93. /// .into_make_service();
  94. /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
  95. /// axum::serve(listener, router).await.unwrap();
  96. /// }
  97. /// ```
  98. #[allow(dead_code)]
  99. fn register_server_functions(self) -> Self
  100. where
  101. Self: Sized,
  102. {
  103. self.register_server_functions_with_context(Default::default())
  104. }
  105. /// Registers server functions with some additional context to insert into the [`DioxusServerContext`] for that handler.
  106. ///
  107. /// # Example
  108. /// ```rust, no_run
  109. /// # use dioxus_lib::prelude::*;
  110. /// # use dioxus_fullstack::prelude::*;
  111. /// # use std::sync::Arc;
  112. /// #[tokio::main]
  113. /// async fn main() {
  114. /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
  115. /// let router = axum::Router::new()
  116. /// // Register server functions routes with the default handler
  117. /// .register_server_functions_with_context(Arc::new(vec![Box::new(|| Box::new(1234567890u32))]))
  118. /// .into_make_service();
  119. /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
  120. /// axum::serve(listener, router).await.unwrap();
  121. /// }
  122. /// ```
  123. fn register_server_functions_with_context(self, context_providers: ContextProviders) -> Self;
  124. }
  125. impl<S> DioxusRouterFnExt<S> for Router<S>
  126. where
  127. S: Send + Sync + Clone + 'static,
  128. {
  129. fn register_server_functions_with_context(
  130. mut self,
  131. context_providers: ContextProviders,
  132. ) -> Self {
  133. use http::method::Method;
  134. for (path, method) in server_fn::axum::server_fn_paths() {
  135. tracing::trace!("Registering server function: {} {}", method, path);
  136. let context_providers = context_providers.clone();
  137. let handler = move |req| handle_server_fns_inner(path, context_providers, req);
  138. self = match method {
  139. Method::GET => self.route(path, get(handler)),
  140. Method::POST => self.route(path, post(handler)),
  141. Method::PUT => self.route(path, put(handler)),
  142. _ => unimplemented!("Unsupported server function method: {}", method),
  143. };
  144. }
  145. self
  146. }
  147. }
  148. /// A handler for Dioxus server functions. This will run the server function and return the result.
  149. async fn handle_server_fns_inner(
  150. path: &str,
  151. additional_context: ContextProviders,
  152. req: Request<Body>,
  153. ) -> impl IntoResponse {
  154. let path_string = path.to_string();
  155. let (parts, body) = req.into_parts();
  156. let req = Request::from_parts(parts.clone(), body);
  157. let method = req.method().clone();
  158. if let Some(mut service) =
  159. server_fn::axum::get_server_fn_service(&path_string, method)
  160. {
  161. // Create the server context with info from the request
  162. let server_context = DioxusServerContext::new(parts);
  163. // Provide additional context from the render state
  164. add_server_context(&server_context, &additional_context);
  165. // store Accepts and Referrer in case we need them for redirect (below)
  166. let accepts_html = req
  167. .headers()
  168. .get(ACCEPT)
  169. .and_then(|v| v.to_str().ok())
  170. .map(|v| v.contains("text/html"))
  171. .unwrap_or(false);
  172. let referrer = req.headers().get(REFERER).cloned();
  173. // actually run the server fn (which may use the server context)
  174. let fut = with_server_context(server_context.clone(), || service.run(req));
  175. let mut res = ProvideServerContext::new(fut, server_context.clone()).await;
  176. // it it accepts text/html (i.e., is a plain form post) and doesn't already have a
  177. // Location set, then redirect to Referer
  178. if accepts_html {
  179. if let Some(referrer) = referrer {
  180. let has_location = res.headers().get(LOCATION).is_some();
  181. if !has_location {
  182. *res.status_mut() = StatusCode::FOUND;
  183. res.headers_mut().insert(LOCATION, referrer);
  184. }
  185. }
  186. }
  187. // apply the response parts from the server context to the response
  188. server_context.send_response(&mut res);
  189. Ok(res)
  190. } else {
  191. Response::builder().status(StatusCode::BAD_REQUEST).body(
  192. {
  193. #[cfg(target_family = "wasm")]
  194. {
  195. Body::from(format!(
  196. "No server function found for path: {path_string}\nYou may need to explicitly register the server function with `register_explicit`, rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
  197. ))
  198. }
  199. #[cfg(not(target_family = "wasm"))]
  200. {
  201. Body::from(format!(
  202. "No server function found for path: {path_string}\nYou may need to rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
  203. ))
  204. }
  205. }
  206. )
  207. }
  208. .expect("could not build Response")
  209. }
  210. pub(crate) fn add_server_context(
  211. server_context: &DioxusServerContext,
  212. context_providers: &ContextProviders,
  213. ) {
  214. for index in 0..context_providers.len() {
  215. let context_providers = context_providers.clone();
  216. server_context.insert_boxed_factory(Box::new(move || context_providers[index]()));
  217. }
  218. }
  219. /// State used by [`render_handler`] to render a dioxus component with axum
  220. #[derive(Clone)]
  221. pub struct RenderHandleState {
  222. config: ServeConfig,
  223. build_virtual_dom: Arc<dyn Fn() -> VirtualDom + Send + Sync>,
  224. ssr_state: once_cell::sync::OnceCell<SSRState>,
  225. }
  226. impl RenderHandleState {
  227. /// Create a new [`RenderHandleState`]
  228. pub fn new(config: ServeConfig, root: fn() -> Element) -> Self {
  229. Self {
  230. config,
  231. build_virtual_dom: Arc::new(move || VirtualDom::new(root)),
  232. ssr_state: Default::default(),
  233. }
  234. }
  235. /// Create a new [`RenderHandleState`] with a custom [`VirtualDom`] factory. This method can be used to pass context into the root component of your application.
  236. pub fn new_with_virtual_dom_factory(
  237. config: ServeConfig,
  238. build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
  239. ) -> Self {
  240. Self {
  241. config,
  242. build_virtual_dom: Arc::new(build_virtual_dom),
  243. ssr_state: Default::default(),
  244. }
  245. }
  246. /// Set the [`ServeConfig`] for this [`RenderHandleState`]
  247. pub fn with_config(mut self, config: ServeConfig) -> Self {
  248. self.config = config;
  249. self
  250. }
  251. /// Set the [`SSRState`] for this [`RenderHandleState`]. Sharing a [`SSRState`] between multiple [`RenderHandleState`]s is more efficient than creating a new [`SSRState`] for each [`RenderHandleState`].
  252. pub fn with_ssr_state(mut self, ssr_state: SSRState) -> Self {
  253. self.ssr_state = once_cell::sync::OnceCell::new();
  254. if self.ssr_state.set(ssr_state).is_err() {
  255. panic!("SSRState already set");
  256. }
  257. self
  258. }
  259. fn ssr_state(&self) -> &SSRState {
  260. self.ssr_state.get_or_init(|| SSRState::new(&self.config))
  261. }
  262. }
  263. /// SSR renderer handler for Axum with added context injection.
  264. ///
  265. /// # Example
  266. /// ```rust,no_run
  267. /// #![allow(non_snake_case)]
  268. /// use std::sync::{Arc, Mutex};
  269. ///
  270. /// use axum::routing::get;
  271. /// use dioxus::prelude::*;
  272. ///
  273. /// fn app() -> Element {
  274. /// rsx! {
  275. /// "hello!"
  276. /// }
  277. /// }
  278. ///
  279. /// #[tokio::main]
  280. /// async fn main() {
  281. /// let addr = dioxus::cli_config::fullstack_address_or_localhost();
  282. /// let router = axum::Router::new()
  283. /// // Register server functions, etc.
  284. /// // Note you can use `register_server_functions_with_context`
  285. /// // to inject the context into server functions running outside
  286. /// // of an SSR render context.
  287. /// .fallback(get(render_handler)
  288. /// .with_state(RenderHandleState::new(ServeConfig::new().unwrap(), app))
  289. /// )
  290. /// .into_make_service();
  291. /// let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
  292. /// axum::serve(listener, router).await.unwrap();
  293. /// }
  294. /// ```
  295. pub async fn render_handler(
  296. State(state): State<RenderHandleState>,
  297. request: Request<Body>,
  298. ) -> impl IntoResponse {
  299. let cfg = &state.config;
  300. let ssr_state = state.ssr_state();
  301. let build_virtual_dom = {
  302. let build_virtual_dom = state.build_virtual_dom.clone();
  303. let context_providers = state.config.context_providers.clone();
  304. move || {
  305. let mut vdom = build_virtual_dom();
  306. for state in context_providers.as_slice() {
  307. vdom.insert_any_root_context(state());
  308. }
  309. vdom
  310. }
  311. };
  312. let (parts, _) = request.into_parts();
  313. let url = parts
  314. .uri
  315. .path_and_query()
  316. .ok_or(StatusCode::BAD_REQUEST)?
  317. .to_string();
  318. let parts: Arc<parking_lot::RwLock<http::request::Parts>> =
  319. Arc::new(parking_lot::RwLock::new(parts));
  320. // Create the server context with info from the request
  321. let server_context = DioxusServerContext::from_shared_parts(parts.clone());
  322. // Provide additional context from the render state
  323. add_server_context(&server_context, &state.config.context_providers);
  324. match ssr_state
  325. .render(url, cfg, build_virtual_dom, &server_context)
  326. .await
  327. {
  328. Ok((freshness, rx)) => {
  329. let mut response = axum::response::Html::from(Body::from_stream(rx)).into_response();
  330. freshness.write(response.headers_mut());
  331. server_context.send_response(&mut response);
  332. Result::<http::Response<axum::body::Body>, StatusCode>::Ok(response)
  333. }
  334. Err(SSRError::Incremental(e)) => {
  335. tracing::error!("Failed to render page: {}", e);
  336. Ok(report_err(e).into_response())
  337. }
  338. Err(SSRError::Routing(e)) => {
  339. tracing::trace!("Page not found: {}", e);
  340. Ok(Response::builder()
  341. .status(StatusCode::NOT_FOUND)
  342. .body(Body::from("Page not found"))
  343. .unwrap())
  344. }
  345. }
  346. }
  347. fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
  348. Response::builder()
  349. .status(StatusCode::INTERNAL_SERVER_ERROR)
  350. .body(body::Body::new(format!("Error: {}", e)))
  351. .unwrap()
  352. }