Browse Source

hot reloading intigration

Evan Almloff 2 years ago
parent
commit
7214130c40

+ 7 - 3
packages/server/Cargo.toml

@@ -15,13 +15,13 @@ warp = { version = "0.3.3", optional = true }
 http-body = { version = "0.4.5", optional = true }
 
 # axum
-axum = { version = "0.6.1", optional = true }
+axum = { version = "0.6.1", features = ["ws"], optional = true }
 tower-http = { version = "0.4.0", optional = true, features = ["fs"] }
 hyper = { version = "0.14.25", optional = true }
 axum-macros = "0.3.7"
 
 # salvo
-salvo = { version = "0.37.7", optional = true, features = ["serve-static"] }
+salvo = { version = "0.37.7", optional = true, features = ["serve-static", "ws"] }
 serde = "1.0.159"
 
 # Dioxus + SSR
@@ -35,12 +35,16 @@ tokio = { version = "1.27.0", features = ["full"], optional = true }
 object-pool = "0.5.4"
 anymap = "0.12.1"
 
+serde_json = { version = "1.0.95", optional = true }
+tokio-stream = { version = "0.1.12", features = ["sync"], optional = true }
+futures-util = { version = "0.3.28", optional = true }
+
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
 dioxus-hot-reload = { path = "../hot-reload" }
 
 [features]
 default = ["hot-reload"]
-hot-reload = []
+hot-reload = ["serde_json", "tokio-stream", "futures-util"]
 warp = ["dep:warp", "http-body", "ssr"]
 axum = ["dep:axum", "tower-http", "hyper", "ssr"]
 salvo = ["dep:salvo", "hyper", "ssr"]

+ 65 - 6
packages/server/src/adapters/axum_adapter.rs

@@ -2,7 +2,7 @@ use std::{error::Error, sync::Arc};
 
 use axum::{
     body::{self, Body, BoxBody, Full},
-    extract::State,
+    extract::{State, WebSocketUpgrade},
     handler::Handler,
     http::{HeaderMap, Request, Response, StatusCode},
     response::IntoResponse,
@@ -36,6 +36,8 @@ pub trait DioxusRouterExt<S> {
         server_fn_route: &'static str,
         cfg: impl Into<ServeConfig<P>>,
     ) -> Self;
+
+    fn connect_hot_reload(self) -> Self;
 }
 
 impl<S> DioxusRouterExt<S> for Router<S>
@@ -92,7 +94,7 @@ where
                 continue;
             }
             let route = path
-                .strip_prefix(&cfg.assets_path)
+                .strip_prefix(cfg.assets_path)
                 .unwrap()
                 .iter()
                 .map(|segment| {
@@ -111,10 +113,26 @@ where
         }
 
         // Add server functions and render index.html
-        self.register_server_fns(server_fn_route).route(
-            "/",
-            get(render_handler).with_state((cfg, SSRState::default())),
-        )
+        self.connect_hot_reload()
+            .register_server_fns(server_fn_route)
+            .route(
+                "/",
+                get(render_handler).with_state((cfg, SSRState::default())),
+            )
+    }
+
+    fn connect_hot_reload(self) -> Self {
+        #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+        {
+            self.route(
+                "/_dioxus/hot_reload",
+                get(hot_reload_handler).with_state(crate::hot_reload::HotReloadState::default()),
+            )
+        }
+        #[cfg(not(all(debug_assertions, feature = "hot-reload", feature = "ssr")))]
+        {
+            self
+        }
     }
 }
 
@@ -195,3 +213,44 @@ fn report_err<E: Error>(e: E) -> Response<BoxBody> {
         .body(body::boxed(format!("Error: {}", e)))
         .unwrap()
 }
+
+#[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+pub async fn hot_reload_handler(
+    ws: WebSocketUpgrade,
+    State(state): State<crate::hot_reload::HotReloadState>,
+) -> impl IntoResponse {
+    use axum::extract::ws::Message;
+    use futures_util::StreamExt;
+
+    ws.on_upgrade(|mut socket| async move {
+        println!("🔥 Hot Reload WebSocket connected");
+        {
+            // update any rsx calls that changed before the websocket connected.
+            {
+                println!("🔮 Finding updates since last compile...");
+                let templates_read = state.templates.read().await;
+
+                for template in &*templates_read {
+                    if socket
+                        .send(Message::Text(serde_json::to_string(&template).unwrap()))
+                        .await
+                        .is_err()
+                    {
+                        return;
+                    }
+                }
+            }
+            println!("finished");
+        }
+
+        let mut rx = tokio_stream::wrappers::WatchStream::from_changes(state.message_receiver);
+        while let Some(change) = rx.next().await {
+            if let Some(template) = change {
+                let template = { serde_json::to_string(&template).unwrap() };
+                if socket.send(Message::Text(template)).await.is_err() {
+                    break;
+                };
+            }
+        }
+    })
+}

+ 86 - 1
packages/server/src/adapters/salvo_adapter.rs

@@ -17,6 +17,7 @@ use crate::{
 
 pub trait DioxusRouterExt {
     fn register_server_fns(self, server_fn_route: &'static str) -> Self;
+
     fn register_server_fns_with_handler<H>(
         self,
         server_fn_route: &'static str,
@@ -24,11 +25,14 @@ pub trait DioxusRouterExt {
     ) -> Self
     where
         H: Handler + 'static;
+
     fn serve_dioxus_application<P: Clone + Send + Sync + 'static>(
         self,
         server_fn_path: &'static str,
         cfg: impl Into<ServeConfig<P>>,
     ) -> Self;
+
+    fn connect_hot_reload(self) -> Self;
 }
 
 impl DioxusRouterExt for Router {
@@ -92,9 +96,14 @@ impl DioxusRouterExt for Router {
             self = self.push(Router::with_path(route).get(serve_dir))
         }
 
-        self.register_server_fns(server_fn_route)
+        self.connect_hot_reload()
+            .register_server_fns(server_fn_route)
             .push(Router::with_path("/").get(SSRHandler { cfg }))
     }
+
+    fn connect_hot_reload(self) -> Self {
+        self.push(Router::with_path("/_dioxus/hot_reload").get(HotReloadHandler::default()))
+    }
 }
 
 struct SSRHandler<P: Clone> {
@@ -217,3 +226,79 @@ fn handle_error(error: impl Error + Send + Sync, res: &mut Response) {
     resp_err.render(format!("Internal Server Error: {}", error));
     *res = resp_err;
 }
+
+#[cfg(not(all(debug_assertions, feature = "hot-reload", feature = "ssr")))]
+#[derive(Default)]
+pub struct HotReloadHandler;
+
+#[cfg(not(all(debug_assertions, feature = "hot-reload", feature = "ssr")))]
+#[handler]
+impl HotReloadHandler {
+    async fn handle(
+        &self,
+        _req: &mut Request,
+        _depot: &mut Depot,
+        _res: &mut Response,
+    ) -> Result<(), salvo::http::StatusError> {
+        Err(salvo::http::StatusError::not_found())
+    }
+}
+
+#[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+#[derive(Default)]
+pub struct HotReloadHandler {
+    state: crate::hot_reload::HotReloadState,
+}
+
+#[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+#[handler]
+impl HotReloadHandler {
+    async fn handle(
+        &self,
+        req: &mut Request,
+        _depot: &mut Depot,
+        res: &mut Response,
+    ) -> Result<(), salvo::http::StatusError> {
+        use salvo::ws::Message;
+        use salvo::ws::WebSocketUpgrade;
+
+        let state = self.state.clone();
+
+        WebSocketUpgrade::new()
+            .upgrade(req, res, |mut websocket| async move {
+                use futures_util::StreamExt;
+
+                println!("🔥 Hot Reload WebSocket connected");
+                {
+                    // update any rsx calls that changed before the websocket connected.
+                    {
+                        println!("🔮 Finding updates since last compile...");
+                        let templates_read = state.templates.read().await;
+
+                        for template in &*templates_read {
+                            if websocket
+                                .send(Message::text(serde_json::to_string(&template).unwrap()))
+                                .await
+                                .is_err()
+                            {
+                                return;
+                            }
+                        }
+                    }
+                    println!("finished");
+                }
+
+                let mut rx =
+                    tokio_stream::wrappers::WatchStream::from_changes(state.message_receiver);
+                while let Some(change) = rx.next().await {
+                    if let Some(template) = change {
+                        let template = { serde_json::to_string(&template).unwrap() };
+                        if websocket.send(Message::text(template)).await.is_err() {
+                            break;
+                        };
+                    }
+                }
+            })
+            .await
+    }
+}

+ 68 - 1
packages/server/src/adapters/warp_adapter.rs

@@ -64,7 +64,8 @@ pub fn serve_dioxus_application<P: Clone + Send + Sync + 'static>(
     // Serve the dist folder and the index.html file
     let serve_dir = warp::fs::dir(cfg.assets_path);
 
-    register_server_fns(server_fn_route)
+    connect_hot_reload()
+        .or(register_server_fns(server_fn_route))
         .or(warp::path::end()
             .and(warp::get())
             .and(with_ssr_state())
@@ -159,3 +160,69 @@ fn report_err<E: Error>(e: E) -> Box<dyn warp::Reply> {
             .unwrap(),
     ) as Box<dyn warp::Reply>
 }
+
+pub fn connect_hot_reload() -> impl Filter<Extract = (impl Reply,), Error = warp::Rejection> {
+    #[cfg(not(all(debug_assertions, feature = "hot-reload", feature = "ssr")))]
+    {
+        warp::path("_dioxus/hot_reload").and(warp::ws()).map(|| {
+            Response::builder()
+                .status(StatusCode::NOT_FOUND)
+                .body("Not Found".into())
+                .unwrap()
+        })
+    }
+    #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+    {
+        use crate::hot_reload::HotReloadState;
+        let state = HotReloadState::default();
+
+        warp::path("_dioxus")
+            .and(warp::path("hot_reload"))
+            .and(warp::ws())
+            .and(warp::any().map(move || state.clone()))
+            .map(move |ws: warp::ws::Ws, state: HotReloadState| {
+                #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+                ws.on_upgrade(move |mut websocket| {
+                    async move {
+                        use futures_util::sink::SinkExt;
+                        use futures_util::StreamExt;
+                        use warp::ws::Message;
+
+                        println!("🔥 Hot Reload WebSocket connected");
+                        {
+                            // update any rsx calls that changed before the websocket connected.
+                            {
+                                println!("🔮 Finding updates since last compile...");
+                                let templates_read = state.templates.read().await;
+
+                                for template in &*templates_read {
+                                    if websocket
+                                        .send(Message::text(
+                                            serde_json::to_string(&template).unwrap(),
+                                        ))
+                                        .await
+                                        .is_err()
+                                    {
+                                        return;
+                                    }
+                                }
+                            }
+                            println!("finished");
+                        }
+
+                        let mut rx = tokio_stream::wrappers::WatchStream::from_changes(
+                            state.message_receiver,
+                        );
+                        while let Some(change) = rx.next().await {
+                            if let Some(template) = change {
+                                let template = { serde_json::to_string(&template).unwrap() };
+                                if websocket.send(Message::text(template)).await.is_err() {
+                                    break;
+                                };
+                            }
+                        }
+                    }
+                })
+            })
+    }
+}

+ 46 - 0
packages/server/src/hot_reload.rs

@@ -0,0 +1,46 @@
+use std::sync::Arc;
+
+use dioxus_core::Template;
+use tokio::sync::{
+    watch::{channel, Receiver},
+    RwLock,
+};
+
+#[derive(Clone)]
+pub struct HotReloadState {
+    // The cache of all templates that have been modified since the last time we checked
+    pub(crate) templates: Arc<RwLock<std::collections::HashSet<dioxus_core::Template<'static>>>>,
+    // The channel to send messages to the hot reload thread
+    pub(crate) message_receiver: Receiver<Option<Template<'static>>>,
+}
+
+impl Default for HotReloadState {
+    fn default() -> Self {
+        let templates = Arc::new(RwLock::new(std::collections::HashSet::new()));
+        let (tx, rx) = channel(None);
+
+        dioxus_hot_reload::connect({
+            let templates = templates.clone();
+            move |msg| match msg {
+                dioxus_hot_reload::HotReloadMsg::UpdateTemplate(template) => {
+                    {
+                        let mut templates = templates.blocking_write();
+                        templates.insert(template);
+                    }
+
+                    if let Err(err) = tx.send(Some(template)) {
+                        log::error!("Failed to send hot reload message: {}", err);
+                    }
+                }
+                dioxus_hot_reload::HotReloadMsg::Shutdown => {
+                    std::process::exit(0);
+                }
+            }
+        });
+
+        Self {
+            templates,
+            message_receiver: rx,
+        }
+    }
+}

+ 2 - 0
packages/server/src/lib.rs

@@ -2,6 +2,8 @@
 use dioxus_core::prelude::*;
 
 mod adapters;
+#[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
+mod hot_reload;
 #[cfg(feature = "ssr")]
 pub mod render;
 #[cfg(feature = "ssr")]

+ 0 - 24
packages/server/src/render.rs

@@ -9,36 +9,12 @@ use crate::prelude::ServeConfig;
 pub struct SSRState {
     // We keep a cache of renderers to avoid re-creating them on every request. They are boxed to make them very cheap to move
     renderers: Arc<object_pool::Pool<Renderer>>,
-    #[cfg(all(debug_assertions, feature = "hot-reload"))]
-    // The cache of all templates that have been modified since the last time we checked
-    templates: Arc<std::sync::RwLock<std::collections::HashSet<dioxus_core::Template<'static>>>>,
 }
 
 impl Default for SSRState {
     fn default() -> Self {
-        #[cfg(all(debug_assertions, feature = "hot-reload"))]
-        let templates = {
-            let templates = Arc::new(std::sync::RwLock::new(std::collections::HashSet::new()));
-            dioxus_hot_reload::connect({
-                let templates = templates.clone();
-                move |msg| match msg {
-                    dioxus_hot_reload::HotReloadMsg::UpdateTemplate(template) => {
-                        if let Ok(mut templates) = templates.write() {
-                            templates.insert(template);
-                        }
-                    }
-                    dioxus_hot_reload::HotReloadMsg::Shutdown => {
-                        std::process::exit(0);
-                    }
-                }
-            });
-            templates
-        };
-
         Self {
             renderers: Arc::new(object_pool::Pool::new(10, pre_renderer)),
-            #[cfg(all(debug_assertions, feature = "hot-reload"))]
-            templates,
         }
     }
 }