axum_adapter.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. //! Dioxus utilities for the [Axum](https://docs.rs/axum/latest/axum/index.html) server framework.
  2. //!
  3. //! # Example
  4. //! ```rust
  5. //! #![allow(non_snake_case)]
  6. //! use dioxus_lib::prelude::*;
  7. //! use dioxus_fullstack::prelude::*;
  8. //!
  9. //! fn main() {
  10. //! #[cfg(feature = "web")]
  11. //! // Hydrate the application on the client
  12. //! dioxus_web::launch_cfg(app, dioxus_web::Config::new().hydrate(true));
  13. //! #[cfg(feature = "server")]
  14. //! {
  15. //! tokio::runtime::Runtime::new()
  16. //! .unwrap()
  17. //! .block_on(async move {
  18. //! let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  19. //! axum::Server::bind(&addr)
  20. //! .serve(
  21. //! axum::Router::new()
  22. //! // Server side render the application, serve static assets, and register server functions
  23. //! .serve_dioxus_application(ServeConfig::builder().build(), || VirtualDom::new(app))
  24. //! .into_make_service(),
  25. //! )
  26. //! .await
  27. //! .unwrap();
  28. //! });
  29. //! }
  30. //! }
  31. //!
  32. //! fn app() -> Element {
  33. //! let text = use_signal(|| "...".to_string());
  34. //!
  35. //! rsx! {
  36. //! button {
  37. //! onclick: move |_| {
  38. //! to_owned![text];
  39. //! async move {
  40. //! if let Ok(data) = get_server_data().await {
  41. //! text.set(data);
  42. //! }
  43. //! }
  44. //! },
  45. //! "Run a server function"
  46. //! }
  47. //! "Server said: {text}"
  48. //! }
  49. //! }
  50. //!
  51. //! #[server(GetServerData)]
  52. //! async fn get_server_data() -> Result<String, ServerFnError> {
  53. //! Ok("Hello from the server!".to_string())
  54. //! }
  55. //! ```
  56. use axum::routing::*;
  57. use axum::{
  58. body::{self, Body},
  59. extract::State,
  60. http::{Request, Response, StatusCode},
  61. response::IntoResponse,
  62. routing::{get, post},
  63. Router,
  64. };
  65. use dioxus_lib::prelude::VirtualDom;
  66. use http::header::*;
  67. use server_fn::error::NoCustomError;
  68. use server_fn::error::ServerFnErrorSerde;
  69. use std::sync::Arc;
  70. use crate::{
  71. prelude::*, render::SSRState, serve_config::ServeConfig, server_context::DioxusServerContext,
  72. };
  73. /// A extension trait with utilities for integrating Dioxus with your Axum router.
  74. pub trait DioxusRouterExt<S> {
  75. /// Registers server functions with the default handler. This handler function will pass an empty [`DioxusServerContext`] to your server functions.
  76. ///
  77. /// # Example
  78. /// ```rust
  79. /// use dioxus_lib::prelude::*;
  80. /// use dioxus_fullstack::prelude::*;
  81. ///
  82. /// #[tokio::main]
  83. /// async fn main() {
  84. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  85. /// axum::Server::bind(&addr)
  86. /// .serve(
  87. /// axum::Router::new()
  88. /// // Register server functions routes with the default handler
  89. /// .register_server_fns("")
  90. /// .into_make_service(),
  91. /// )
  92. /// .await
  93. /// .unwrap();
  94. /// }
  95. /// ```
  96. fn register_server_fns(self) -> Self;
  97. /// Register the web RSX hot reloading endpoint. This will enable hot reloading for your application in debug mode when you call [`dioxus_hot_reload::hot_reload_init`].
  98. ///
  99. /// # Example
  100. /// ```rust
  101. /// #![allow(non_snake_case)]
  102. /// use dioxus_fullstack::prelude::*;
  103. ///
  104. /// #[tokio::main]
  105. /// async fn main() {
  106. /// hot_reload_init!();
  107. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  108. /// axum::Server::bind(&addr)
  109. /// .serve(
  110. /// axum::Router::new()
  111. /// // Connect to hot reloading in debug mode
  112. /// .connect_hot_reload()
  113. /// .into_make_service(),
  114. /// )
  115. /// .await
  116. /// .unwrap();
  117. /// }
  118. /// ```
  119. fn connect_hot_reload(self) -> Self;
  120. /// Serves the static WASM for your Dioxus application (except the generated index.html).
  121. ///
  122. /// # Example
  123. /// ```rust
  124. /// #![allow(non_snake_case)]
  125. /// use dioxus_lib::prelude::*;
  126. /// use dioxus_fullstack::prelude::*;
  127. ///
  128. /// #[tokio::main]
  129. /// async fn main() {
  130. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  131. /// axum::Server::bind(&addr)
  132. /// .serve(
  133. /// axum::Router::new()
  134. /// // Server side render the application, serve static assets, and register server functions
  135. /// .serve_static_assets("dist")
  136. /// // Server render the application
  137. /// // ...
  138. /// .into_make_service(),
  139. /// )
  140. /// .await
  141. /// .unwrap();
  142. /// }
  143. ///
  144. /// fn app() -> Element {
  145. /// unimplemented!()
  146. /// }
  147. /// ```
  148. fn serve_static_assets(self, assets_path: impl Into<std::path::PathBuf>) -> Self;
  149. /// Serves the Dioxus application. This will serve a complete server side rendered application.
  150. /// This will serve static assets, server render the application, register server functions, and intigrate with hot reloading.
  151. ///
  152. /// # Example
  153. /// ```rust
  154. /// #![allow(non_snake_case)]
  155. /// use dioxus_lib::prelude::*;
  156. /// use dioxus_fullstack::prelude::*;
  157. ///
  158. /// #[tokio::main]
  159. /// async fn main() {
  160. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  161. /// axum::Server::bind(&addr)
  162. /// .serve(
  163. /// axum::Router::new()
  164. /// // Server side render the application, serve static assets, and register server functions
  165. /// .serve_dioxus_application("", ServerConfig::new(app, ()))
  166. /// .into_make_service(),
  167. /// )
  168. /// .await
  169. /// .unwrap();
  170. /// }
  171. ///
  172. /// fn app() -> Element {
  173. /// unimplemented!()
  174. /// }
  175. /// ```
  176. fn serve_dioxus_application(
  177. self,
  178. cfg: impl Into<ServeConfig>,
  179. build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
  180. ) -> Self;
  181. }
  182. impl<S> DioxusRouterExt<S> for Router<S>
  183. where
  184. S: Send + Sync + Clone + 'static,
  185. {
  186. fn register_server_fns(mut self) -> Self {
  187. use http::method::Method;
  188. for (path, method) in server_fn::axum::server_fn_paths() {
  189. tracing::trace!("Registering server function: {} {}", method, path);
  190. let handler = move |req| handle_server_fns_inner(path, || {}, req);
  191. self = match method {
  192. Method::GET => self.route(path, get(handler)),
  193. Method::POST => self.route(path, post(handler)),
  194. Method::PUT => self.route(path, put(handler)),
  195. _ => todo!(),
  196. };
  197. }
  198. self
  199. }
  200. fn serve_static_assets(mut self, assets_path: impl Into<std::path::PathBuf>) -> Self {
  201. use tower_http::services::{ServeDir, ServeFile};
  202. let assets_path = assets_path.into();
  203. // Serve all files in dist folder except index.html
  204. let dir = std::fs::read_dir(&assets_path).unwrap_or_else(|e| {
  205. panic!(
  206. "Couldn't read assets directory at {:?}: {}",
  207. &assets_path, e
  208. )
  209. });
  210. for entry in dir.flatten() {
  211. let path = entry.path();
  212. if path.ends_with("index.html") {
  213. continue;
  214. }
  215. let route = path
  216. .strip_prefix(&assets_path)
  217. .unwrap()
  218. .iter()
  219. .map(|segment| {
  220. segment.to_str().unwrap_or_else(|| {
  221. panic!("Failed to convert path segment {:?} to string", segment)
  222. })
  223. })
  224. .collect::<Vec<_>>()
  225. .join("/");
  226. let route = format!("/{}", route);
  227. if path.is_dir() {
  228. self = self.nest_service(&route, ServeDir::new(path));
  229. } else {
  230. self = self.nest_service(&route, ServeFile::new(path));
  231. }
  232. }
  233. self
  234. }
  235. fn serve_dioxus_application(
  236. self,
  237. cfg: impl Into<ServeConfig>,
  238. build_virtual_dom: impl Fn() -> VirtualDom + Send + Sync + 'static,
  239. ) -> Self {
  240. let cfg = cfg.into();
  241. let ssr_state = SSRState::new(&cfg);
  242. // Add server functions and render index.html
  243. self.serve_static_assets(cfg.assets_path.clone())
  244. .connect_hot_reload()
  245. .register_server_fns()
  246. .fallback(get(render_handler).with_state((cfg, Arc::new(build_virtual_dom), ssr_state)))
  247. }
  248. fn connect_hot_reload(self) -> Self {
  249. #[cfg(all(debug_assertions, feature = "hot-reload"))]
  250. {
  251. self.nest(
  252. "/_dioxus",
  253. Router::new()
  254. .route(
  255. "/disconnect",
  256. get(|ws: axum::extract::WebSocketUpgrade| async {
  257. ws.on_upgrade(|mut ws| async move {
  258. use axum::extract::ws::Message;
  259. let _ = ws.send(Message::Text("connected".into())).await;
  260. loop {
  261. if ws.recv().await.is_none() {
  262. break;
  263. }
  264. }
  265. })
  266. }),
  267. )
  268. .route("/hot_reload", get(hot_reload_handler)),
  269. )
  270. }
  271. #[cfg(not(all(debug_assertions, feature = "hot-reload")))]
  272. {
  273. self
  274. }
  275. }
  276. }
  277. fn apply_request_parts_to_response<B>(
  278. headers: hyper::header::HeaderMap,
  279. response: &mut axum::response::Response<B>,
  280. ) {
  281. let mut_headers = response.headers_mut();
  282. for (key, value) in headers.iter() {
  283. mut_headers.insert(key, value.clone());
  284. }
  285. }
  286. type AxumHandler<F> = (
  287. F,
  288. ServeConfig,
  289. SSRState,
  290. Arc<dyn Fn() -> VirtualDom + Send + Sync>,
  291. );
  292. /// SSR renderer handler for Axum with added context injection.
  293. ///
  294. /// # Example
  295. /// ```rust,no_run
  296. /// #![allow(non_snake_case)]
  297. /// use std::sync::{Arc, Mutex};
  298. ///
  299. /// use axum::routing::get;
  300. /// use dioxus_lib::prelude::*;
  301. /// use dioxus_fullstack::{axum_adapter::render_handler_with_context, prelude::*};
  302. ///
  303. /// fn app() -> Element {
  304. /// rsx! {
  305. /// "hello!"
  306. /// }
  307. /// }
  308. ///
  309. /// #[tokio::main]
  310. /// async fn main() {
  311. /// let cfg = ServerConfig::new(app, ())
  312. /// .assets_path("dist")
  313. /// .build();
  314. /// let ssr_state = SSRState::new(&cfg);
  315. ///
  316. /// // This could be any state you want to be accessible from your server
  317. /// // functions using `[DioxusServerContext::get]`.
  318. /// let state = Arc::new(Mutex::new("state".to_string()));
  319. ///
  320. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  321. /// axum::Server::bind(&addr)
  322. /// .serve(
  323. /// axum::Router::new()
  324. /// // Register server functions, etc.
  325. /// // Note you probably want to use `register_server_fns_with_handler`
  326. /// // to inject the context into server functions running outside
  327. /// // of an SSR render context.
  328. /// .fallback(get(render_handler_with_context).with_state((
  329. /// move |ctx| ctx.insert(state.clone()).unwrap(),
  330. /// cfg,
  331. /// ssr_state,
  332. /// )))
  333. /// .into_make_service(),
  334. /// )
  335. /// .await
  336. /// .unwrap();
  337. /// }
  338. /// ```
  339. pub async fn render_handler_with_context<F: FnMut(&mut DioxusServerContext)>(
  340. State((mut inject_context, cfg, ssr_state, virtual_dom_factory)): State<AxumHandler<F>>,
  341. request: Request<Body>,
  342. ) -> impl IntoResponse {
  343. let (parts, _) = request.into_parts();
  344. let url = parts.uri.path_and_query().unwrap().to_string();
  345. let parts: Arc<tokio::sync::RwLock<http::request::Parts>> =
  346. Arc::new(tokio::sync::RwLock::new(parts));
  347. let mut server_context = DioxusServerContext::new(parts.clone());
  348. inject_context(&mut server_context);
  349. match ssr_state
  350. .render(url, &cfg, move || virtual_dom_factory(), &server_context)
  351. .await
  352. {
  353. Ok(rendered) => {
  354. let crate::render::RenderResponse { html, freshness } = rendered;
  355. let mut response = axum::response::Html::from(html).into_response();
  356. freshness.write(response.headers_mut());
  357. let headers = server_context.response_parts().unwrap().headers.clone();
  358. apply_request_parts_to_response(headers, &mut response);
  359. response
  360. }
  361. Err(e) => {
  362. tracing::error!("Failed to render page: {}", e);
  363. report_err(e).into_response()
  364. }
  365. }
  366. }
  367. type RenderHandlerExtractor = (
  368. ServeConfig,
  369. Arc<dyn Fn() -> VirtualDom + Send + Sync>,
  370. SSRState,
  371. );
  372. /// SSR renderer handler for Axum
  373. pub async fn render_handler(
  374. State((cfg, virtual_dom_factory, ssr_state)): State<RenderHandlerExtractor>,
  375. request: Request<Body>,
  376. ) -> impl IntoResponse {
  377. render_handler_with_context(
  378. State((|_: &mut _| (), cfg, ssr_state, virtual_dom_factory)),
  379. request,
  380. )
  381. .await
  382. }
  383. fn report_err<E: std::fmt::Display>(e: E) -> Response<axum::body::Body> {
  384. Response::builder()
  385. .status(StatusCode::INTERNAL_SERVER_ERROR)
  386. .body(body::Body::new(format!("Error: {}", e)))
  387. .unwrap()
  388. }
  389. /// A handler for Dioxus web hot reload websocket. This will send the updated static parts of the RSX to the client when they change.
  390. #[cfg(all(debug_assertions, feature = "hot-reload"))]
  391. pub async fn hot_reload_handler(ws: axum::extract::WebSocketUpgrade) -> impl IntoResponse {
  392. use axum::extract::ws::Message;
  393. use futures_util::StreamExt;
  394. let state = crate::hot_reload::spawn_hot_reload().await;
  395. ws.on_upgrade(move |mut socket| async move {
  396. println!("🔥 Hot Reload WebSocket connected");
  397. {
  398. // update any rsx calls that changed before the websocket connected.
  399. {
  400. println!("🔮 Finding updates since last compile...");
  401. let templates_read = state.templates.read().await;
  402. for template in &*templates_read {
  403. if socket
  404. .send(Message::Text(serde_json::to_string(&template).unwrap()))
  405. .await
  406. .is_err()
  407. {
  408. return;
  409. }
  410. }
  411. }
  412. println!("finished");
  413. }
  414. let mut rx =
  415. tokio_stream::wrappers::WatchStream::from_changes(state.message_receiver.clone());
  416. while let Some(change) = rx.next().await {
  417. if let Some(template) = change {
  418. let template = { serde_json::to_string(&template).unwrap() };
  419. if socket.send(Message::Text(template)).await.is_err() {
  420. break;
  421. };
  422. }
  423. }
  424. })
  425. }
  426. fn get_local_pool() -> tokio_util::task::LocalPoolHandle {
  427. use once_cell::sync::OnceCell;
  428. static LOCAL_POOL: OnceCell<tokio_util::task::LocalPoolHandle> = OnceCell::new();
  429. LOCAL_POOL
  430. .get_or_init(|| {
  431. tokio_util::task::LocalPoolHandle::new(
  432. std::thread::available_parallelism()
  433. .map(Into::into)
  434. .unwrap_or(1),
  435. )
  436. })
  437. .clone()
  438. }
  439. /// A handler for Dioxus server functions. This will run the server function and return the result.
  440. async fn handle_server_fns_inner(
  441. path: &str,
  442. additional_context: impl Fn() + 'static + Clone + Send,
  443. req: Request<Body>,
  444. ) -> impl IntoResponse {
  445. use server_fn::middleware::Service;
  446. let (tx, rx) = tokio::sync::oneshot::channel();
  447. let path_string = path.to_string();
  448. get_local_pool().spawn_pinned(move || async move {
  449. let (parts, body) = req.into_parts();
  450. let req = Request::from_parts(parts.clone(), body);
  451. let res = if let Some(mut service) =
  452. server_fn::axum::get_server_fn_service(&path_string)
  453. {
  454. let server_context = DioxusServerContext::new(Arc::new(tokio::sync::RwLock::new(parts)));
  455. additional_context();
  456. // store Accepts and Referrer in case we need them for redirect (below)
  457. let accepts_html = req
  458. .headers()
  459. .get(ACCEPT)
  460. .and_then(|v| v.to_str().ok())
  461. .map(|v| v.contains("text/html"))
  462. .unwrap_or(false);
  463. let referrer = req.headers().get(REFERER).cloned();
  464. // actually run the server fn
  465. let mut res = service.run(req).await;
  466. // it it accepts text/html (i.e., is a plain form post) and doesn't already have a
  467. // Location set, then redirect to to Referer
  468. if accepts_html {
  469. if let Some(referrer) = referrer {
  470. let has_location = res.headers().get(LOCATION).is_some();
  471. if !has_location {
  472. *res.status_mut() = StatusCode::FOUND;
  473. res.headers_mut().insert(LOCATION, referrer);
  474. }
  475. }
  476. }
  477. // apply the response parts from the server context to the response
  478. let mut res_options = server_context.response_parts_mut().unwrap();
  479. res.headers_mut().extend(res_options.headers.drain());
  480. Ok(res)
  481. } else {
  482. Response::builder().status(StatusCode::BAD_REQUEST).body(
  483. {
  484. #[cfg(target_family = "wasm")]
  485. {
  486. Body::from(format!(
  487. "No server function found for path: {path_string}\nYou may need to explicitly register the server function with `register_explicit`, rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
  488. ))
  489. }
  490. #[cfg(not(target_family = "wasm"))]
  491. {
  492. Body::from(format!(
  493. "No server function found for path: {path_string}\nYou may need to rebuild your wasm binary to update a server function link or make sure the prefix your server and client use for server functions match.",
  494. ))
  495. }
  496. }
  497. )
  498. }
  499. .expect("could not build Response");
  500. _ = tx.send(res);
  501. });
  502. rx.await.unwrap_or_else(|e| {
  503. (
  504. StatusCode::INTERNAL_SERVER_ERROR,
  505. ServerFnError::<NoCustomError>::ServerError(e.to_string())
  506. .ser()
  507. .unwrap_or_default(),
  508. )
  509. .into_response()
  510. })
  511. }