axum_adapter.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. //! Dioxus utilities for the [Axum](https://docs.rs/axum/latest/axum/index.html) server framework.
  2. //!
  3. //! # Example
  4. //! ```rust
  5. //! #![allow(non_snake_case)]
  6. //! use dioxus_lib::prelude::*;
  7. //! use dioxus_fullstack::prelude::*;
  8. //!
  9. //! fn main() {
  10. //! #[cfg(feature = "web")]
  11. //! // Hydrate the application on the client
  12. //! dioxus_web::launch_cfg(app, dioxus_web::Config::new().hydrate(true));
  13. //! #[cfg(feature = "ssr")]
  14. //! {
  15. //! tokio::runtime::Runtime::new()
  16. //! .unwrap()
  17. //! .block_on(async move {
  18. //! let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  19. //! axum::Server::bind(&addr)
  20. //! .serve(
  21. //! axum::Router::new()
  22. //! // Server side render the application, serve static assets, and register server functions
  23. //! .serve_dioxus_application("", ServerConfig::new(app, ()))
  24. //! .into_make_service(),
  25. //! )
  26. //! .await
  27. //! .unwrap();
  28. //! });
  29. //! }
  30. //! }
  31. //!
  32. //! fn app() -> Element {
  33. //! let text = use_signal(|| "...".to_string());
  34. //!
  35. //! rsx! {
  36. //! button {
  37. //! onclick: move |_| {
  38. //! to_owned![text];
  39. //! async move {
  40. //! if let Ok(data) = get_server_data().await {
  41. //! text.set(data);
  42. //! }
  43. //! }
  44. //! },
  45. //! "Run a server function"
  46. //! }
  47. //! "Server said: {text}"
//! }
  49. //! }
  50. //!
  51. //! #[server(GetServerData)]
  52. //! async fn get_server_data() -> Result<String, ServerFnError> {
  53. //! Ok("Hello from the server!".to_string())
  54. //! }
  55. //! ```
  56. use axum::{
  57. body::{self, Body, BoxBody},
  58. extract::State,
  59. handler::Handler,
  60. http::{Request, Response, StatusCode},
  61. response::IntoResponse,
  62. routing::{get, post},
  63. Router,
  64. };
  65. use server_fn::{Encoding, ServerFunctionRegistry};
  66. use std::sync::Arc;
  67. use std::sync::RwLock;
  68. use crate::{
  69. prelude::*, render::SSRState, serve_config::ServeConfig, server_context::DioxusServerContext,
  70. server_fn::DioxusServerFnRegistry,
  71. };
  72. /// A extension trait with utilities for integrating Dioxus with your Axum router.
  73. pub trait DioxusRouterExt<S> {
  74. /// Registers server functions with a custom handler function. This allows you to pass custom context to your server functions by generating a [`DioxusServerContext`] from the request.
  75. ///
  76. /// # Example
  77. /// ```rust
  78. /// use dioxus_lib::prelude::*;
  79. /// use dioxus_fullstack::prelude::*;
  80. ///
  81. /// #[tokio::main]
  82. /// async fn main() {
  83. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  84. /// axum::Server::bind(&addr)
  85. /// .serve(
  86. /// axum::Router::new()
  87. /// .register_server_fns_with_handler("", |func| {
  88. /// move |req: Request<Body>| async move {
  89. /// let (parts, body) = req.into_parts();
  90. /// let parts: Arc<http::request::Parts> = Arc::new(parts.into());
  91. /// let server_context = DioxusServerContext::new(parts.clone());
  92. /// server_fn_handler(server_context, func.clone(), parts, body).await
  93. /// }
  94. /// })
  95. /// .into_make_service(),
  96. /// )
  97. /// .await
  98. /// .unwrap();
  99. /// }
  100. /// ```
  101. fn register_server_fns_with_handler<H, T>(
  102. self,
  103. server_fn_route: &'static str,
  104. handler: impl FnMut(server_fn::ServerFnTraitObj<()>) -> H,
  105. ) -> Self
  106. where
  107. H: Handler<T, S>,
  108. T: 'static,
  109. S: Clone + Send + Sync + 'static;
  110. /// Registers server functions with the default handler. This handler function will pass an empty [`DioxusServerContext`] to your server functions.
  111. ///
  112. /// # Example
  113. /// ```rust
  114. /// use dioxus_lib::prelude::*;
  115. /// use dioxus_fullstack::prelude::*;
  116. ///
  117. /// #[tokio::main]
  118. /// async fn main() {
  119. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  120. /// axum::Server::bind(&addr)
  121. /// .serve(
  122. /// axum::Router::new()
  123. /// // Register server functions routes with the default handler
  124. /// .register_server_fns("")
  125. /// .into_make_service(),
  126. /// )
  127. /// .await
  128. /// .unwrap();
  129. /// }
  130. /// ```
  131. fn register_server_fns(self, server_fn_route: &'static str) -> Self;
  132. /// Register the web RSX hot reloading endpoint. This will enable hot reloading for your application in debug mode when you call [`dioxus_hot_reload::hot_reload_init`].
  133. ///
  134. /// # Example
  135. /// ```rust
  136. /// #![allow(non_snake_case)]
  137. /// use dioxus_fullstack::prelude::*;
  138. ///
  139. /// #[tokio::main]
  140. /// async fn main() {
  141. /// hot_reload_init!();
  142. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  143. /// axum::Server::bind(&addr)
  144. /// .serve(
  145. /// axum::Router::new()
  146. /// // Connect to hot reloading in debug mode
  147. /// .connect_hot_reload()
  148. /// .into_make_service(),
  149. /// )
  150. /// .await
  151. /// .unwrap();
  152. /// }
  153. /// ```
  154. fn connect_hot_reload(self) -> Self;
  155. /// Serves the static WASM for your Dioxus application (except the generated index.html).
  156. ///
  157. /// # Example
  158. /// ```rust
  159. /// #![allow(non_snake_case)]
  160. /// use dioxus_lib::prelude::*;
  161. /// use dioxus_fullstack::prelude::*;
  162. ///
  163. /// #[tokio::main]
  164. /// async fn main() {
  165. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  166. /// axum::Server::bind(&addr)
  167. /// .serve(
  168. /// axum::Router::new()
  169. /// // Server side render the application, serve static assets, and register server functions
  170. /// .serve_static_assets("dist")
  171. /// // Server render the application
  172. /// // ...
  173. /// .into_make_service(),
  174. /// )
  175. /// .await
  176. /// .unwrap();
  177. /// }
  178. ///
  179. /// fn app() -> Element {
  180. /// todo!()
  181. /// }
  182. /// ```
  183. fn serve_static_assets(self, assets_path: impl Into<std::path::PathBuf>) -> Self;
  184. /// Serves the Dioxus application. This will serve a complete server side rendered application.
  185. /// This will serve static assets, server render the application, register server functions, and intigrate with hot reloading.
  186. ///
  187. /// # Example
  188. /// ```rust
  189. /// #![allow(non_snake_case)]
  190. /// use dioxus_lib::prelude::*;
  191. /// use dioxus_fullstack::prelude::*;
  192. ///
  193. /// #[tokio::main]
  194. /// async fn main() {
  195. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  196. /// axum::Server::bind(&addr)
  197. /// .serve(
  198. /// axum::Router::new()
  199. /// // Server side render the application, serve static assets, and register server functions
  200. /// .serve_dioxus_application("", ServerConfig::new(app, ()))
  201. /// .into_make_service(),
  202. /// )
  203. /// .await
  204. /// .unwrap();
  205. /// }
  206. ///
  207. /// fn app() -> Element {
  208. /// todo!()
  209. /// }
  210. /// ```
  211. fn serve_dioxus_application<P: Clone + serde::Serialize + Send + Sync + 'static>(
  212. self,
  213. server_fn_route: &'static str,
  214. cfg: impl Into<ServeConfig<P>>,
  215. ) -> Self;
  216. }
  217. impl<S> DioxusRouterExt<S> for Router<S>
  218. where
  219. S: Send + Sync + Clone + 'static,
  220. {
  221. fn register_server_fns_with_handler<H, T>(
  222. self,
  223. server_fn_route: &'static str,
  224. mut handler: impl FnMut(server_fn::ServerFnTraitObj<()>) -> H,
  225. ) -> Self
  226. where
  227. H: Handler<T, S, Body>,
  228. T: 'static,
  229. S: Clone + Send + Sync + 'static,
  230. {
  231. let mut router = self;
  232. for server_fn_path in DioxusServerFnRegistry::paths_registered() {
  233. let func = DioxusServerFnRegistry::get(server_fn_path).unwrap();
  234. let full_route = format!("{server_fn_route}/{server_fn_path}");
  235. match func.encoding() {
  236. Encoding::Url | Encoding::Cbor => {
  237. router = router.route(&full_route, post(handler(func)));
  238. }
  239. Encoding::GetJSON | Encoding::GetCBOR => {
  240. router = router.route(&full_route, get(handler(func)));
  241. }
  242. }
  243. }
  244. router
  245. }
  246. fn register_server_fns(self, server_fn_route: &'static str) -> Self {
  247. self.register_server_fns_with_handler(server_fn_route, |func| {
  248. move |req: Request<Body>| {
  249. let mut service = crate::server_fn_service(Default::default(), func);
  250. async move {
  251. let (req, body) = req.into_parts();
  252. let req = Request::from_parts(req, body);
  253. let res = service.run(req);
  254. match res.await {
  255. Ok(res) => Ok::<_, std::convert::Infallible>(res.map(|b| b.into())),
  256. Err(e) => {
  257. let mut res = Response::new(Body::from(e.to_string()));
  258. *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
  259. Ok(res)
  260. }
  261. }
  262. }
  263. }
  264. })
  265. }
  266. fn serve_static_assets(mut self, assets_path: impl Into<std::path::PathBuf>) -> Self {
  267. use tower_http::services::{ServeDir, ServeFile};
  268. let assets_path = assets_path.into();
  269. // Serve all files in dist folder except index.html
  270. let dir = std::fs::read_dir(&assets_path).unwrap_or_else(|e| {
  271. panic!(
  272. "Couldn't read assets directory at {:?}: {}",
  273. &assets_path, e
  274. )
  275. });
  276. for entry in dir.flatten() {
  277. let path = entry.path();
  278. if path.ends_with("index.html") {
  279. continue;
  280. }
  281. let route = path
  282. .strip_prefix(&assets_path)
  283. .unwrap()
  284. .iter()
  285. .map(|segment| {
  286. segment.to_str().unwrap_or_else(|| {
  287. panic!("Failed to convert path segment {:?} to string", segment)
  288. })
  289. })
  290. .collect::<Vec<_>>()
  291. .join("/");
  292. let route = format!("/{}", route);
  293. if path.is_dir() {
  294. self = self.nest_service(&route, ServeDir::new(path));
  295. } else {
  296. self = self.nest_service(&route, ServeFile::new(path));
  297. }
  298. }
  299. self
  300. }
  301. fn serve_dioxus_application<P: Clone + serde::Serialize + Send + Sync + 'static>(
  302. self,
  303. server_fn_route: &'static str,
  304. cfg: impl Into<ServeConfig<P>>,
  305. ) -> Self {
  306. let cfg = cfg.into();
  307. let ssr_state = SSRState::new(&cfg);
  308. // Add server functions and render index.html
  309. self.serve_static_assets(cfg.assets_path)
  310. .connect_hot_reload()
  311. .register_server_fns(server_fn_route)
  312. .fallback(get(render_handler).with_state((cfg, ssr_state)))
  313. }
  314. fn connect_hot_reload(self) -> Self {
  315. #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
  316. {
  317. self.nest(
  318. "/_dioxus",
  319. Router::new()
  320. .route(
  321. "/disconnect",
  322. get(|ws: axum::extract::WebSocketUpgrade| async {
  323. ws.on_upgrade(|mut ws| async move {
  324. use axum::extract::ws::Message;
  325. let _ = ws.send(Message::Text("connected".into())).await;
  326. loop {
  327. if ws.recv().await.is_none() {
  328. break;
  329. }
  330. }
  331. })
  332. }),
  333. )
  334. .route("/hot_reload", get(hot_reload_handler)),
  335. )
  336. }
  337. #[cfg(not(all(debug_assertions, feature = "hot-reload", feature = "ssr")))]
  338. {
  339. self
  340. }
  341. }
  342. }
  343. fn apply_request_parts_to_response<B>(
  344. headers: hyper::header::HeaderMap,
  345. response: &mut axum::response::Response<B>,
  346. ) {
  347. let mut_headers = response.headers_mut();
  348. for (key, value) in headers.iter() {
  349. mut_headers.insert(key, value.clone());
  350. }
  351. }
  352. /// SSR renderer handler for Axum with added context injection.
  353. ///
  354. /// # Example
  355. /// ```rust,no_run
  356. /// #![allow(non_snake_case)]
  357. /// use std::sync::{Arc, Mutex};
  358. ///
  359. /// use axum::routing::get;
  360. /// use dioxus_lib::prelude::*;
  361. /// use dioxus_fullstack::{axum_adapter::render_handler_with_context, prelude::*};
  362. ///
  363. /// fn app() -> Element {
  364. /// rsx! {
  365. /// "hello!"
  366. /// }
  367. /// }
  368. ///
  369. /// #[tokio::main]
  370. /// async fn main() {
  371. /// let cfg = ServerConfig::new(app, ())
  372. /// .assets_path("dist")
  373. /// .build();
  374. /// let ssr_state = SSRState::new(&cfg);
  375. ///
  376. /// // This could be any state you want to be accessible from your server
  377. /// // functions using `[DioxusServerContext::get]`.
  378. /// let state = Arc::new(Mutex::new("state".to_string()));
  379. ///
  380. /// let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));
  381. /// axum::Server::bind(&addr)
  382. /// .serve(
  383. /// axum::Router::new()
  384. /// // Register server functions, etc.
  385. /// // Note you probably want to use `register_server_fns_with_handler`
  386. /// // to inject the context into server functions running outside
  387. /// // of an SSR render context.
  388. /// .fallback(get(render_handler_with_context).with_state((
  389. /// move |ctx| ctx.insert(state.clone()).unwrap(),
  390. /// cfg,
  391. /// ssr_state,
  392. /// )))
  393. /// .into_make_service(),
  394. /// )
  395. /// .await
  396. /// .unwrap();
  397. /// }
  398. /// ```
  399. pub async fn render_handler_with_context<
  400. P: Clone + serde::Serialize + Send + Sync + 'static,
  401. F: FnMut(&mut DioxusServerContext),
  402. >(
  403. State((mut inject_context, cfg, ssr_state)): State<(F, ServeConfig<P>, SSRState)>,
  404. request: Request<Body>,
  405. ) -> impl IntoResponse {
  406. let (parts, _) = request.into_parts();
  407. let url = parts.uri.path_and_query().unwrap().to_string();
  408. let parts: Arc<RwLock<http::request::Parts>> = Arc::new(RwLock::new(parts.into()));
  409. let mut server_context = DioxusServerContext::new(parts.clone());
  410. inject_context(&mut server_context);
  411. match ssr_state.render(url, &cfg, &server_context).await {
  412. Ok(rendered) => {
  413. let crate::render::RenderResponse { html, freshness } = rendered;
  414. let mut response = axum::response::Html::from(html).into_response();
  415. freshness.write(response.headers_mut());
  416. let headers = server_context.response_parts().unwrap().headers.clone();
  417. apply_request_parts_to_response(headers, &mut response);
  418. response
  419. }
  420. Err(e) => {
  421. tracing::error!("Failed to render page: {}", e);
  422. report_err(e).into_response()
  423. }
  424. }
  425. }
  426. /// SSR renderer handler for Axum
  427. pub async fn render_handler<P: Clone + serde::Serialize + Send + Sync + 'static>(
  428. State((cfg, ssr_state)): State<(ServeConfig<P>, SSRState)>,
  429. request: Request<Body>,
  430. ) -> impl IntoResponse {
  431. render_handler_with_context(State((|_: &mut _| (), cfg, ssr_state)), request).await
  432. }
  433. fn report_err<E: std::fmt::Display>(e: E) -> Response<BoxBody> {
  434. Response::builder()
  435. .status(StatusCode::INTERNAL_SERVER_ERROR)
  436. .body(body::boxed(format!("Error: {}", e)))
  437. .unwrap()
  438. }
  439. /// A handler for Dioxus web hot reload websocket. This will send the updated static parts of the RSX to the client when they change.
  440. #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
  441. pub async fn hot_reload_handler(ws: axum::extract::WebSocketUpgrade) -> impl IntoResponse {
  442. use axum::extract::ws::Message;
  443. use futures_util::StreamExt;
  444. let state = crate::hot_reload::spawn_hot_reload().await;
  445. ws.on_upgrade(move |mut socket| async move {
  446. println!("🔥 Hot Reload WebSocket connected");
  447. {
  448. // update any rsx calls that changed before the websocket connected.
  449. {
  450. println!("🔮 Finding updates since last compile...");
  451. let templates_read = state.templates.read().await;
  452. for template in &*templates_read {
  453. if socket
  454. .send(Message::Text(serde_json::to_string(&template).unwrap()))
  455. .await
  456. .is_err()
  457. {
  458. return;
  459. }
  460. }
  461. }
  462. println!("finished");
  463. }
  464. let mut rx =
  465. tokio_stream::wrappers::WatchStream::from_changes(state.message_receiver.clone());
  466. while let Some(change) = rx.next().await {
  467. if let Some(template) = change {
  468. let template = { serde_json::to_string(&template).unwrap() };
  469. if socket.send(Message::Text(template)).await.is_err() {
  470. break;
  471. };
  472. }
  473. }
  474. })
  475. }