Forráskód Böngészése

Fix websocket server functions and add an example (#4107)

* Fix websocket server function

* Add websocket test

* add a websocket example

* fix clippy
Evan Almloff 1 hónapja
szülő
commit
6b70ba7cd2

+ 11 - 0
Cargo.lock

@@ -4334,6 +4334,7 @@ name = "dioxus-playwright-fullstack-test"
 version = "0.1.0"
 dependencies = [
  "dioxus",
+ "futures",
  "serde",
  "tokio",
 ]
@@ -4469,6 +4470,7 @@ dependencies = [
  "dioxus-lib",
  "dioxus-router",
  "dioxus-ssr",
+ "enumset",
  "futures-channel",
  "futures-util",
  "generational-box",
@@ -5565,6 +5567,15 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "fullstack-websocket-example"
+version = "0.1.0"
+dependencies = [
+ "dioxus",
+ "futures",
+ "tokio",
+]
+
 [[package]]
 name = "funty"
 version = "2.0.0"

+ 1 - 0
Cargo.toml

@@ -109,6 +109,7 @@ members = [
     "examples/fullstack-streaming",
     "examples/fullstack-desktop",
     "examples/fullstack-auth",
+    "examples/fullstack-websockets",
 
     # Playwright tests
     "packages/playwright-tests/liveview",

+ 4 - 0
examples/fullstack-websockets/.gitignore

@@ -0,0 +1,4 @@
+dist
+target
+static
+.dioxus

+ 16 - 0
examples/fullstack-websockets/Cargo.toml

@@ -0,0 +1,16 @@
+[package]
+name = "fullstack-websocket-example"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+dioxus = { workspace = true, features = ["fullstack"] }
+futures.workspace = true
+tokio = { workspace = true, features = ["full"], optional = true }
+
+[features]
+server = ["dioxus/server", "dep:tokio"]
+web = ["dioxus/web"]

+ 69 - 0
examples/fullstack-websockets/src/main.rs

@@ -0,0 +1,69 @@
+#![allow(non_snake_case)]
+use dioxus::prelude::{
+    server_fn::{codec::JsonEncoding, BoxedStream, Websocket},
+    *,
+};
+use futures::{channel::mpsc, SinkExt, StreamExt};
+
+fn main() {
+    launch(app);
+}
+
+fn app() -> Element {
+    let mut uppercase = use_signal(String::new);
+    let mut uppercase_channel = use_signal(|| None);
+
+    // Start the websocket connection in a background task
+    use_future(move || async move {
+        let (tx, rx) = mpsc::channel(1);
+        let mut receiver = uppercase_ws(rx.into()).await.unwrap();
+        // Store the channel in a signal for use in the input handler
+        uppercase_channel.set(Some(tx));
+        // Whenever we get a message from the server, update the uppercase signal
+        while let Some(Ok(msg)) = receiver.next().await {
+            uppercase.set(msg);
+        }
+    });
+
+    rsx! {
+        input {
+            oninput: move |e| async move {
+                if let Some(mut uppercase_channel) = uppercase_channel() {
+                    let msg = e.value();
+                    uppercase_channel.send(Ok(msg)).await.unwrap();
+                }
+            },
+        }
+        "Uppercase: {uppercase}"
+    }
+}
+
+// The server macro accepts a protocol parameter which implements the protocol trait. The protocol
+// controls how the inputs and outputs are encoded when handling the server function. In this case,
+// the websocket<json, json> protocol can encode a stream input and stream output where messages are
+// serialized as JSON
+#[server(protocol = Websocket<JsonEncoding, JsonEncoding>)]
+async fn uppercase_ws(
+    input: BoxedStream<String, ServerFnError>,
+) -> Result<BoxedStream<String, ServerFnError>, ServerFnError> {
+    let mut input = input;
+
+    // Create a channel with the output of the websocket
+    let (mut tx, rx) = mpsc::channel(1);
+
+    // Spawn a task that processes the input stream and sends any new messages to the output
+    tokio::spawn(async move {
+        while let Some(msg) = input.next().await {
+            if tx
+                .send(msg.map(|msg| msg.to_ascii_uppercase()))
+                .await
+                .is_err()
+            {
+                break;
+            }
+        }
+    });
+
+    // Return the output stream
+    Ok(rx.into())
+}

+ 8 - 0
packages/playwright-tests/fullstack.spec.js

@@ -68,3 +68,11 @@ test("document elements", async ({ page }) => {
   const main = page.locator("#main");
   await expect(main).toHaveCSS("font-family", "Roboto");
 });
+
+test("websockets", async ({ page }) => {
+  await page.goto("http://localhost:3333");
+  // wait until the websocket div is mounted
+  const wsDiv = page.locator("div#websocket-div");
+  await expect(wsDiv).toHaveText("Received: HELLO WORLD");
+});
+

+ 1 - 0
packages/playwright-tests/fullstack/Cargo.toml

@@ -8,6 +8,7 @@ publish = false
 
 [dependencies]
 dioxus = { workspace = true, features = ["fullstack"] }
+futures.workspace = true
 serde = "1.0.218"
 tokio = { workspace = true, features = ["full"], optional = true }
 

+ 48 - 1
packages/playwright-tests/fullstack/src/main.rs

@@ -5,7 +5,14 @@
 // - Hydration
 
 #![allow(non_snake_case)]
-use dioxus::{prelude::*, CapturedError};
+use dioxus::{
+    prelude::{
+        server_fn::{codec::JsonEncoding, BoxedStream, Websocket},
+        *,
+    },
+    CapturedError,
+};
+use futures::{channel::mpsc, SinkExt, StreamExt};
 
 fn main() {
     dioxus::LaunchBuilder::new()
@@ -43,6 +50,7 @@ fn app() -> Element {
         OnMounted {}
         DefaultServerFnCodec {}
         DocumentElements {}
+        WebSockets {}
     }
 }
 
@@ -156,3 +164,42 @@ fn DocumentElements() -> Element {
         document::Style { id: "style-head", "body {{ font-family: 'Roboto'; }}" }
     }
 }
+
+#[server(protocol = Websocket<JsonEncoding, JsonEncoding>)]
+async fn echo_ws(
+    input: BoxedStream<String, ServerFnError>,
+) -> Result<BoxedStream<String, ServerFnError>, ServerFnError> {
+    let mut input = input;
+
+    let (mut tx, rx) = mpsc::channel(1);
+
+    tokio::spawn(async move {
+        while let Some(msg) = input.next().await {
+            let _ = tx.send(msg.map(|msg| msg.to_ascii_uppercase())).await;
+        }
+    });
+
+    Ok(rx.into())
+}
+
+/// This component tests websocket server functions
+#[component]
+fn WebSockets() -> Element {
+    let mut received = use_signal(String::new);
+    use_future(move || async move {
+        let (mut tx, rx) = mpsc::channel(1);
+        let mut receiver = echo_ws(rx.into()).await.unwrap();
+        tx.send(Ok("hello world".to_string())).await.unwrap();
+        while let Some(Ok(msg)) = receiver.next().await {
+            println!("Received: {}", msg);
+            received.set(msg);
+        }
+    });
+
+    rsx! {
+        div {
+            id: "websocket-div",
+            "Received: {received}"
+        }
+    }
+}

+ 1 - 0
packages/server/Cargo.toml

@@ -42,6 +42,7 @@ tracing-futures = { workspace = true }
 once_cell = { workspace = true }
 async-trait = { workspace = true }
 serde = { workspace = true }
+enumset = "1.1.5"
 
 futures-util = { workspace = true }
 futures-channel = { workspace = true }

+ 138 - 11
packages/server/src/context.rs

@@ -1,10 +1,44 @@
+use enumset::{EnumSet, EnumSetType};
 use parking_lot::RwLock;
 use std::any::Any;
 use std::collections::HashMap;
+use std::sync::atomic::AtomicU32;
 use std::sync::Arc;
 
 type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
 
+#[derive(EnumSetType)]
+enum ResponsePartsModified {
+    Version,
+    Headers,
+    Status,
+    Extensions,
+    Body,
+}
+
+struct AtomicResponsePartsModified {
+    modified: AtomicU32,
+}
+
+impl AtomicResponsePartsModified {
+    fn new() -> Self {
+        Self {
+            modified: AtomicU32::new(EnumSet::<ResponsePartsModified>::empty().as_u32()),
+        }
+    }
+
+    fn set(&self, part: ResponsePartsModified) {
+        let modified =
+            EnumSet::from_u32(self.modified.load(std::sync::atomic::Ordering::Relaxed)) | part;
+        self.modified
+            .store(modified.as_u32(), std::sync::atomic::Ordering::Relaxed);
+    }
+
+    fn is_modified(&self, part: ResponsePartsModified) -> bool {
+        self.modified.load(std::sync::atomic::Ordering::Relaxed) & (1 << part as usize) != 0
+    }
+}
+
 /// A shared context for server functions that contains information about the request and middleware state.
 ///
 /// You should not construct this directly inside components or server functions. Instead use [`server_context()`] to get the server context from the current request.
@@ -24,6 +58,7 @@ type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
 #[derive(Clone)]
 pub struct DioxusServerContext {
     shared_context: Arc<RwLock<SendSyncAnyMap>>,
+    response_parts_modified: Arc<AtomicResponsePartsModified>,
     response_parts: Arc<RwLock<http::response::Parts>>,
     pub(crate) parts: Arc<RwLock<http::request::Parts>>,
     response_sent: Arc<std::sync::atomic::AtomicBool>,
@@ -48,6 +83,7 @@ impl Default for DioxusServerContext {
     fn default() -> Self {
         Self {
             shared_context: Arc::new(RwLock::new(HashMap::new())),
+            response_parts_modified: Arc::new(AtomicResponsePartsModified::new()),
             response_parts: Arc::new(RwLock::new(
                 http::response::Response::new(()).into_parts().0,
             )),
@@ -59,7 +95,7 @@ impl Default for DioxusServerContext {
 
 mod server_fn_impl {
     use super::*;
-    use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
+    use parking_lot::{MappedRwLockWriteGuard, RwLockReadGuard, RwLockWriteGuard};
     use std::any::{Any, TypeId};
 
     impl DioxusServerContext {
@@ -68,6 +104,7 @@ mod server_fn_impl {
             Self {
                 parts: Arc::new(RwLock::new(parts)),
                 shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
+                response_parts_modified: Arc::new(AtomicResponsePartsModified::new()),
                 response_parts: std::sync::Arc::new(RwLock::new(
                     http::response::Response::new(()).into_parts().0,
                 )),
@@ -81,6 +118,7 @@ mod server_fn_impl {
             Self {
                 parts,
                 shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
+                response_parts_modified: Arc::new(AtomicResponsePartsModified::new()),
                 response_parts: std::sync::Arc::new(RwLock::new(
                     http::response::Response::new(()).into_parts().0,
                 )),
@@ -178,7 +216,7 @@ mod server_fn_impl {
             self.response_parts.read()
         }
 
-        /// Get the response parts from the server context
+        /// Get the headers from the server context mutably
         ///
         #[doc = include_str!("../docs/request_origin.md")]
         ///
@@ -189,13 +227,82 @@ mod server_fn_impl {
         /// #[server]
         /// async fn set_headers() -> Result<(), ServerFnError> {
         ///     let server_context = server_context();
-        ///     server_context.response_parts_mut()
-        ///         .headers
+        ///     server_context.headers_mut()
         ///         .insert("Cookie", http::HeaderValue::from_static("dioxus=fullstack"));
         ///     Ok(())
         /// }
         /// ```
-        pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
+        pub fn headers_mut(&self) -> MappedRwLockWriteGuard<'_, http::HeaderMap> {
+            self.response_parts_modified
+                .set(ResponsePartsModified::Headers);
+            RwLockWriteGuard::map(self.response_parts_mut(), |parts| &mut parts.headers)
+        }
+
+        /// Get the status from the server context mutably
+        ///
+        #[doc = include_str!("../docs/request_origin.md")]
+        ///
+        /// # Example
+        ///
+        /// ```rust, no_run
+        /// # use dioxus::prelude::*;
+        /// #[server]
+        /// async fn set_status() -> Result<(), ServerFnError> {
+        ///     let server_context = server_context();
+        ///     *server_context.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
+        ///     Ok(())
+        /// }
+        /// ```
+        pub fn status_mut(&self) -> MappedRwLockWriteGuard<'_, http::StatusCode> {
+            self.response_parts_modified
+                .set(ResponsePartsModified::Status);
+            RwLockWriteGuard::map(self.response_parts_mut(), |parts| &mut parts.status)
+        }
+
+        /// Get the version from the server context mutably
+        ///
+        #[doc = include_str!("../docs/request_origin.md")]
+        ///
+        /// # Example
+        ///
+        /// ```rust, no_run
+        /// # use dioxus::prelude::*;
+        /// #[server]
+        /// async fn set_version() -> Result<(), ServerFnError> {
+        ///     let server_context = server_context();
+        ///     *server_context.version_mut() = http::Version::HTTP_2;
+        ///     Ok(())
+        /// }
+        /// ```
+        pub fn version_mut(&self) -> MappedRwLockWriteGuard<'_, http::Version> {
+            self.response_parts_modified
+                .set(ResponsePartsModified::Version);
+            RwLockWriteGuard::map(self.response_parts_mut(), |parts| &mut parts.version)
+        }
+
+        /// Get the extensions from the server context mutably
+        ///
+        #[doc = include_str!("../docs/request_origin.md")]
+        ///
+        /// # Example
+        ///
+        /// ```rust, no_run
+        /// # use dioxus::prelude::*;
+        /// #[server]
+        /// async fn set_version() -> Result<(), ServerFnError> {
+        ///     let server_context = server_context();
+        ///     *server_context.version_mut() = http::Version::HTTP_2;
+        ///     Ok(())
+        /// }
+        /// ```
+        pub fn extensions_mut(&self) -> MappedRwLockWriteGuard<'_, http::Extensions> {
+            self.response_parts_modified
+                .set(ResponsePartsModified::Extensions);
+            RwLockWriteGuard::map(self.response_parts_mut(), |parts| &mut parts.extensions)
+        }
+
+        /// Get the response parts mutably. This does not track what parts have been written to so it should not be exposed publicly.
+        fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
             if self
                 .response_sent
                 .load(std::sync::atomic::Ordering::Relaxed)
@@ -280,13 +387,33 @@ mod server_fn_impl {
                 .store(true, std::sync::atomic::Ordering::Relaxed);
             let parts = self.response_parts.read();
 
-            let mut_headers = response.headers_mut();
-            for (key, value) in parts.headers.iter() {
-                mut_headers.insert(key, value.clone());
+            if self
+                .response_parts_modified
+                .is_modified(ResponsePartsModified::Headers)
+            {
+                let mut_headers = response.headers_mut();
+                for (key, value) in parts.headers.iter() {
+                    mut_headers.insert(key, value.clone());
+                }
+            }
+            if self
+                .response_parts_modified
+                .is_modified(ResponsePartsModified::Status)
+            {
+                *response.status_mut() = parts.status;
+            }
+            if self
+                .response_parts_modified
+                .is_modified(ResponsePartsModified::Version)
+            {
+                *response.version_mut() = parts.version;
+            }
+            if self
+                .response_parts_modified
+                .is_modified(ResponsePartsModified::Extensions)
+            {
+                response.extensions_mut().extend(parts.extensions.clone());
             }
-            *response.status_mut() = parts.status;
-            *response.version_mut() = parts.version;
-            response.extensions_mut().extend(parts.extensions.clone());
         }
     }
 }