1
0
Эх сурвалжийг харах

Merge pull request #2005 from ealmloff/fix-suspense

Only poll suspended futures, lazy memos
Jonathan Kelley 1 жил өмнө
parent
commit
295c29db5d

+ 0 - 1
Cargo.lock

@@ -2744,7 +2744,6 @@ version = "0.5.0-alpha.0"
 dependencies = [
  "dioxus",
  "dioxus-core",
- "flume",
  "futures-channel",
  "futures-util",
  "generational-box",

+ 8 - 12
examples/memo_chain.rs

@@ -21,29 +21,25 @@ fn app() -> Element {
         button { onclick: move |_| value += 1, "Increment" }
         button { onclick: move |_| depth += 1, "Add depth" }
         button { onclick: move |_| depth -= 1, "Remove depth" }
-        Child { depth, items, state }
+        if depth() > 0 {
+            Child { depth, items, state }
+        }
     }
 }
 
 #[component]
-fn Child(
-    state: ReadOnlySignal<isize>,
-    items: ReadOnlySignal<Vec<isize>>,
-    depth: ReadOnlySignal<usize>,
-) -> Element {
-    if depth() == 0 {
-        return None;
-    }
-
+fn Child(state: Memo<isize>, items: Memo<Vec<isize>>, depth: ReadOnlySignal<usize>) -> Element {
     // These memos don't get re-computed when early returns happen
     let state = use_memo(move || state() + 1);
-    let item = use_memo(move || items()[depth()]);
+    let item = use_memo(move || items()[depth() - 1]);
     let depth = use_memo(move || depth() - 1);
 
     println!("rendering child: {}", depth());
 
     rsx! {
         h3 { "Depth({depth})-Item({item}): {state}"}
-        Child { depth, state, items }
+        if depth() > 0 {
+            Child { depth, state, items }
+        }
     }
 }

+ 47 - 30
packages/core/src/dirty_scope.rs

@@ -29,9 +29,9 @@
 
 use crate::ScopeId;
 use crate::Task;
+use crate::VirtualDom;
 use std::borrow::Borrow;
 use std::cell::RefCell;
-use std::collections::BTreeSet;
 use std::hash::Hash;
 
 #[derive(Debug, Clone, Copy, Eq)]
@@ -70,50 +70,71 @@ impl Hash for ScopeOrder {
     }
 }
 
-#[derive(Debug, Default)]
-pub struct DirtyScopes {
-    pub(crate) scopes: BTreeSet<ScopeOrder>,
-    pub(crate) tasks: BTreeSet<DirtyTasks>,
-}
-
-impl DirtyScopes {
+impl VirtualDom {
     /// Queue a task to be polled
-    pub fn queue_task(&mut self, task: Task, order: ScopeOrder) {
-        match self.tasks.get(&order) {
+    pub(crate) fn queue_task(&mut self, task: Task, order: ScopeOrder) {
+        match self.dirty_tasks.get(&order) {
             Some(scope) => scope.queue_task(task),
             None => {
                 let scope = DirtyTasks::from(order);
                 scope.queue_task(task);
-                self.tasks.insert(scope);
+                self.dirty_tasks.insert(scope);
             }
         }
     }
 
     /// Queue a scope to be rerendered
-    pub fn queue_scope(&mut self, order: ScopeOrder) {
-        self.scopes.insert(order);
+    pub(crate) fn queue_scope(&mut self, order: ScopeOrder) {
+        self.dirty_scopes.insert(order);
     }
 
     /// Check if there are any dirty scopes
-    pub fn has_dirty_scopes(&self) -> bool {
-        !self.scopes.is_empty()
+    pub(crate) fn has_dirty_scopes(&self) -> bool {
+        !self.dirty_scopes.is_empty()
     }
 
     /// Take any tasks from the highest scope
-    pub fn pop_task(&mut self) -> Option<DirtyTasks> {
-        self.tasks.pop_first()
+    pub(crate) fn pop_task(&mut self) -> Option<DirtyTasks> {
+        let mut task = self.dirty_tasks.pop_first()?;
+
+        // If the scope doesn't exist for whatever reason, then we should skip it
+        while !self.scopes.contains(task.order.id.0) {
+            task = self.dirty_tasks.pop_first()?;
+        }
+
+        Some(task)
     }
 
     /// Take any work from the highest scope. This may include rerunning the scope and/or running tasks
-    pub fn pop_work(&mut self) -> Option<Work> {
-        let dirty_scope = self.scopes.first();
-        let dirty_task = self.tasks.first();
+    pub(crate) fn pop_work(&mut self) -> Option<Work> {
+        let mut dirty_scope = self.dirty_scopes.first();
+        // Pop any invalid scopes off of each dirty task;
+        while let Some(scope) = dirty_scope {
+            if !self.scopes.contains(scope.id.0) {
+                self.dirty_scopes.pop_first();
+                dirty_scope = self.dirty_scopes.first();
+            } else {
+                break;
+            }
+        }
+
+        let mut dirty_task = self.dirty_tasks.first();
+        // Pop any invalid tasks off of each dirty scope;
+        while let Some(task) = dirty_task {
+            if !self.scopes.contains(task.order.id.0) {
+                self.dirty_tasks.pop_first();
+                dirty_task = self.dirty_tasks.first();
+            } else {
+                break;
+            }
+        }
+
         match (dirty_scope, dirty_task) {
             (Some(scope), Some(task)) => {
                 let tasks_order = task.borrow();
                 match scope.cmp(tasks_order) {
                     std::cmp::Ordering::Less => {
-                        let scope = self.scopes.pop_first().unwrap();
+                        let scope = self.dirty_scopes.pop_first().unwrap();
                         Some(Work {
                             scope,
                             rerun_scope: true,
@@ -121,7 +142,7 @@ impl DirtyScopes {
                         })
                     }
                     std::cmp::Ordering::Greater => {
-                        let task = self.tasks.pop_first().unwrap();
+                        let task = self.dirty_tasks.pop_first().unwrap();
                         Some(Work {
                             scope: task.order,
                             rerun_scope: false,
@@ -129,8 +150,8 @@ impl DirtyScopes {
                         })
                     }
                     std::cmp::Ordering::Equal => {
-                        let scope = self.scopes.pop_first().unwrap();
-                        let task = self.tasks.pop_first().unwrap();
+                        let scope = self.dirty_scopes.pop_first().unwrap();
+                        let task = self.dirty_tasks.pop_first().unwrap();
                         Some(Work {
                             scope,
                             rerun_scope: true,
@@ -140,7 +161,7 @@ impl DirtyScopes {
                 }
             }
             (Some(_), None) => {
-                let scope = self.scopes.pop_first().unwrap();
+                let scope = self.dirty_scopes.pop_first().unwrap();
                 Some(Work {
                     scope,
                     rerun_scope: true,
@@ -148,7 +169,7 @@ impl DirtyScopes {
                 })
             }
             (None, Some(_)) => {
-                let task = self.tasks.pop_first().unwrap();
+                let task = self.dirty_tasks.pop_first().unwrap();
                 Some(Work {
                     scope: task.order,
                     rerun_scope: false,
@@ -158,10 +179,6 @@ impl DirtyScopes {
             (None, None) => None,
         }
     }
-
-    pub fn remove(&mut self, scope: &ScopeOrder) {
-        self.scopes.remove(scope);
-    }
 }
 
 #[derive(Debug)]

+ 3 - 3
packages/core/src/global_context.rs

@@ -50,9 +50,9 @@ pub fn provide_root_context<T: 'static + Clone>(value: T) -> T {
         .expect("to be in a dioxus runtime")
 }
 
-/// Suspends the current component
-pub fn suspend() -> Option<Element> {
-    Runtime::with_current_scope(|cx| cx.suspend());
+/// Suspended the current component on a specific task and then return None
+pub fn suspend(task: Task) -> Element {
+    Runtime::with_current_scope(|cx| cx.suspend(task));
     None
 }
 

+ 8 - 2
packages/core/src/runtime.rs

@@ -1,3 +1,5 @@
+use rustc_hash::FxHashSet;
+
 use crate::{
     innerlude::{LocalTask, SchedulerMsg},
     render_signal::RenderSignal,
@@ -24,11 +26,14 @@ pub struct Runtime {
     // We use this to track the current task
     pub(crate) current_task: Cell<Option<Task>>,
 
-    pub(crate) rendering: Cell<bool>,
-
     /// Tasks created with cx.spawn
     pub(crate) tasks: RefCell<slab::Slab<Rc<LocalTask>>>,
 
+    // Currently suspended tasks
+    pub(crate) suspended_tasks: RefCell<FxHashSet<Task>>,
+
+    pub(crate) rendering: Cell<bool>,
+
     pub(crate) sender: futures_channel::mpsc::UnboundedSender<SchedulerMsg>,
 
     // Synchronous tasks need to be run after the next render. The virtual dom stores a list of those tasks to send a signal to them when the next render is done.
@@ -45,6 +50,7 @@ impl Runtime {
             scope_stack: Default::default(),
             current_task: Default::default(),
             tasks: Default::default(),
+            suspended_tasks: Default::default(),
         })
     }
 

+ 3 - 5
packages/core/src/scope_arena.rs

@@ -41,7 +41,6 @@ impl VirtualDom {
         let new_nodes = {
             let context = scope.state();
 
-            context.suspended.set(false);
             context.hook_index.set(0);
 
             // Run all pre-render hooks
@@ -70,12 +69,11 @@ impl VirtualDom {
         self.dirty_scopes
             .remove(&ScopeOrder::new(context.height, scope_id));
 
-        if context.suspended.get() {
+        if let Some(task) = context.last_suspendable_task.take() {
             if matches!(new_nodes, RenderReturn::Aborted(_)) {
-                self.suspended_scopes.insert(context.id);
+                tracing::trace!("Suspending {:?} on {:?}", scope_id, task);
+                self.runtime.suspended_tasks.borrow_mut().insert(task);
             }
-        } else if !self.suspended_scopes.is_empty() {
-            _ = self.suspended_scopes.remove(&context.id);
         }
 
         self.runtime.scope_stack.borrow_mut().pop();

+ 9 - 8
packages/core/src/scope_context.rs

@@ -16,13 +16,14 @@ pub(crate) struct Scope {
     pub(crate) parent_id: Option<ScopeId>,
     pub(crate) height: u32,
     pub(crate) render_count: Cell<usize>,
-    pub(crate) suspended: Cell<bool>,
 
     // Note: the order of the hook and context fields is important. The hooks field must be dropped before the contexts field in case a hook drop implementation tries to access a context.
     pub(crate) hooks: RefCell<Vec<Box<dyn Any>>>,
     pub(crate) hook_index: Cell<usize>,
     pub(crate) shared_contexts: RefCell<Vec<Box<dyn Any>>>,
     pub(crate) spawned_tasks: RefCell<FxHashSet<Task>>,
+    /// The task that was last spawned that may suspend. We use this task to check what task to suspend in the event of an early None return from a component
+    pub(crate) last_suspendable_task: Cell<Option<Task>>,
     pub(crate) before_render: RefCell<Vec<Box<dyn FnMut()>>>,
     pub(crate) after_render: RefCell<Vec<Box<dyn FnMut()>>>,
 }
@@ -40,9 +41,9 @@ impl Scope {
             parent_id,
             height,
             render_count: Cell::new(0),
-            suspended: Cell::new(false),
             shared_contexts: RefCell::new(vec![]),
             spawned_tasks: RefCell::new(FxHashSet::default()),
+            last_suspendable_task: Cell::new(None),
             hooks: RefCell::new(vec![]),
             hook_index: Cell::new(0),
             before_render: RefCell::new(vec![]),
@@ -241,9 +242,9 @@ impl Scope {
         Runtime::with(|rt| rt.spawn(self.id, fut)).expect("Runtime to exist")
     }
 
-    /// Mark this component as suspended and then return None
-    pub fn suspend(&self) -> Option<Element> {
-        self.suspended.set(true);
+    /// Mark this component as suspended on a specific task and then return None
+    pub fn suspend(&self, task: Task) -> Option<Element> {
+        self.last_suspendable_task.set(Some(task));
         None
     }
 
@@ -340,10 +341,10 @@ impl ScopeId {
             .expect("to be in a dioxus runtime")
     }
 
-    /// Suspends the current component
-    pub fn suspend(self) -> Option<Element> {
+    /// Suspended a component on a specific task and then return None
+    pub fn suspend(self, task: Task) -> Option<Element> {
         Runtime::with_scope(self, |cx| {
-            cx.suspend();
+            cx.suspend(task);
         });
         None
     }

+ 2 - 2
packages/core/src/tasks.rs

@@ -171,8 +171,7 @@ impl Runtime {
                 .borrow_mut()
                 .remove(&id);
 
-            // Remove it from the scheduler
-            self.tasks.borrow_mut().try_remove(id.0);
+            self.remove_task(id);
         }
 
         // Remove the scope from the stack
@@ -187,6 +186,7 @@ impl Runtime {
     ///
     /// This does not abort the task, so you'll want to wrap it in an abort handle if that's important to you
     pub(crate) fn remove_task(&self, id: Task) -> Option<Rc<LocalTask>> {
+        self.suspended_tasks.borrow_mut().remove(&id);
         self.tasks.borrow_mut().try_remove(id.0)
     }
 }

+ 88 - 46
packages/core/src/virtual_dom.rs

@@ -2,14 +2,14 @@
 //!
 //! This module provides the primary mechanics to create a hook-based, concurrent VDOM for Rust.
 
-use crate::innerlude::ScopeOrder;
+use crate::innerlude::{DirtyTasks, ScopeOrder};
 use crate::Task;
 use crate::{
     any_props::AnyProps,
     arena::ElementId,
     innerlude::{
-        DirtyScopes, ElementRef, ErrorBoundary, NoOpMutations, SchedulerMsg, ScopeState,
-        VNodeMount, VProps, WriteMutations,
+        ElementRef, ErrorBoundary, NoOpMutations, SchedulerMsg, ScopeState, VNodeMount, VProps,
+        WriteMutations,
     },
     nodes::RenderReturn,
     nodes::{Template, TemplateId},
@@ -18,8 +18,9 @@ use crate::{
     AttributeValue, ComponentFunction, Element, Event, Mutations,
 };
 use futures_util::StreamExt;
-use rustc_hash::{FxHashMap, FxHashSet};
+use rustc_hash::FxHashMap;
 use slab::Slab;
+use std::collections::BTreeSet;
 use std::{any::Any, rc::Rc};
 use tracing::instrument;
 
@@ -185,7 +186,8 @@ use tracing::instrument;
 pub struct VirtualDom {
     pub(crate) scopes: Slab<ScopeState>,
 
-    pub(crate) dirty_scopes: DirtyScopes,
+    pub(crate) dirty_scopes: BTreeSet<ScopeOrder>,
+    pub(crate) dirty_tasks: BTreeSet<DirtyTasks>,
 
     // Maps a template path to a map of byte indexes to templates
     pub(crate) templates: FxHashMap<TemplateId, FxHashMap<usize, Template>>,
@@ -201,9 +203,6 @@ pub struct VirtualDom {
 
     pub(crate) runtime: Rc<Runtime>,
 
-    // Currently suspended scopes
-    pub(crate) suspended_scopes: FxHashSet<ScopeId>,
-
     rx: futures_channel::mpsc::UnboundedReceiver<SchedulerMsg>,
 }
 
@@ -315,11 +314,11 @@ impl VirtualDom {
             runtime: Runtime::new(tx),
             scopes: Default::default(),
             dirty_scopes: Default::default(),
+            dirty_tasks: Default::default(),
             templates: Default::default(),
             queued_templates: Default::default(),
             elements: Default::default(),
             mounts: Default::default(),
-            suspended_scopes: Default::default(),
         };
 
         let root = dom.new_scope(Box::new(root), "app");
@@ -380,7 +379,8 @@ impl VirtualDom {
 
         tracing::event!(tracing::Level::TRACE, "Marking scope {:?} as dirty", id);
         let order = ScopeOrder::new(scope.height(), id);
-        self.dirty_scopes.queue_scope(order);
+        drop(scope);
+        self.queue_scope(order);
     }
 
     /// Mark a task as dirty
@@ -400,7 +400,8 @@ impl VirtualDom {
         );
 
         let order = ScopeOrder::new(scope.height(), scope.id);
-        self.dirty_scopes.queue_task(task, order);
+        drop(scope);
+        self.queue_task(task, order);
     }
 
     /// Call a listener inside the VirtualDom with data from outside the VirtualDom. **The ElementId passed in must be the id of an element with a listener, not a static node or a text node.**
@@ -448,20 +449,13 @@ impl VirtualDom {
     /// ```
     #[instrument(skip(self), level = "trace", name = "VirtualDom::wait_for_work")]
     pub async fn wait_for_work(&mut self) {
-        // And then poll the futures
-        self.poll_tasks().await;
-    }
-
-    /// Poll the scheduler for any work
-    #[instrument(skip(self), level = "trace", name = "VirtualDom::poll_tasks")]
-    async fn poll_tasks(&mut self) {
         loop {
             // Process all events - Scopes are marked dirty, etc
             // Sometimes when wakers fire we get a slew of updates at once, so its important that we drain this completely
             self.process_events();
 
             // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
-            if self.dirty_scopes.has_dirty_scopes() {
+            if self.has_dirty_scopes() {
                 return;
             }
 
@@ -469,17 +463,22 @@ impl VirtualDom {
             let _runtime = RuntimeGuard::new(self.runtime.clone());
 
             // There isn't any more work we can do synchronously. Wait for any new work to be ready
-            match self.rx.next().await.expect("channel should never close") {
-                SchedulerMsg::Immediate(id) => self.mark_dirty(id),
-                SchedulerMsg::TaskNotified(id) => {
-                    // Instead of running the task immediately, we insert it into the runtime's task queue.
-                    // The task may be marked dirty at the same time as the scope that owns the task is dropped.
-                    self.mark_task_dirty(id);
-                }
-            };
+            self.wait_for_event().await;
         }
     }
 
+    /// Wait for the next event to trigger and add it to the queue
+    async fn wait_for_event(&mut self) {
+        match self.rx.next().await.expect("channel should never close") {
+            SchedulerMsg::Immediate(id) => self.mark_dirty(id),
+            SchedulerMsg::TaskNotified(id) => {
+                // Instead of running the task immediately, we insert it into the runtime's task queue.
+                // The task may be marked dirty at the same time as the scope that owns the task is dropped.
+                self.mark_task_dirty(id);
+            }
+        };
+    }
+
     /// Queue any pending events
     fn queue_events(&mut self) {
         // Prevent a task from deadlocking the runtime by repeatedly queueing itself
@@ -494,29 +493,31 @@ impl VirtualDom {
     /// Process all events in the queue until there are no more left
     #[instrument(skip(self), level = "trace", name = "VirtualDom::process_events")]
     pub fn process_events(&mut self) {
-        let _runtime = RuntimeGuard::new(self.runtime.clone());
         self.queue_events();
 
         // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
-        if self.dirty_scopes.has_dirty_scopes() {
+        if self.has_dirty_scopes() {
             return;
         }
 
+        self.poll_tasks()
+    }
+
+    /// Poll any queued tasks
+    #[instrument(skip(self), level = "trace", name = "VirtualDom::poll_tasks")]
+    fn poll_tasks(&mut self) {
+        // Make sure we set the runtime since we're running user code
+        let _runtime = RuntimeGuard::new(self.runtime.clone());
         // Next, run any queued tasks
         // We choose not to poll the deadline since we complete pretty quickly anyways
-        while let Some(task) = self.dirty_scopes.pop_task() {
-            // If the scope doesn't exist for whatever reason, then we should skip it
-            if !self.scopes.contains(task.order.id.0) {
-                continue;
-            }
-
+        while let Some(task) = self.pop_task() {
             // Then poll any tasks that might be pending
             let tasks = task.tasks_queued.into_inner();
             for task in tasks {
                 let _ = self.runtime.handle_task_wakeup(task);
                 // Running that task, may mark a scope higher up as dirty. If it does, return from the function early
                 self.queue_events();
-                if self.dirty_scopes.has_dirty_scopes() {
+                if self.has_dirty_scopes() {
                     return;
                 }
             }
@@ -608,16 +609,10 @@ impl VirtualDom {
 
         // Next, diff any dirty scopes
         // We choose not to poll the deadline since we complete pretty quickly anyways
-        while let Some(work) = self.dirty_scopes.pop_work() {
-            // If the scope doesn't exist for whatever reason, then we should skip it
-            if !self.scopes.contains(work.scope.id.0) {
-                continue;
-            }
-
+        while let Some(work) = self.pop_work() {
             {
                 let _runtime = RuntimeGuard::new(self.runtime.clone());
                 // Then, poll any tasks that might be pending in the scope
-                // This will run effects, so this **must** be done after the scope is diffed
                 for task in work.tasks {
                     let _ = self.runtime.handle_task_wakeup(task);
                 }
@@ -649,15 +644,62 @@ impl VirtualDom {
     #[instrument(skip(self), level = "trace", name = "VirtualDom::wait_for_suspense")]
     pub async fn wait_for_suspense(&mut self) {
         loop {
-            if self.suspended_scopes.is_empty() {
+            if self.runtime.suspended_tasks.borrow().is_empty() {
                 break;
             }
 
             // Wait for a work to be ready (IE new suspense leaves to pop up)
-            self.poll_tasks().await;
+            'wait_for_work: loop {
+                // Process all events - Scopes are marked dirty, etc
+                // Sometimes when wakers fire we get a slew of updates at once, so its important that we drain this completely
+                self.queue_events();
+
+                // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
+                if self.has_dirty_scopes() {
+                    break;
+                }
+
+                {
+                    // Make sure we set the runtime since we're running user code
+                    let _runtime = RuntimeGuard::new(self.runtime.clone());
+                    // Next, run any queued tasks
+                    // We choose not to poll the deadline since we complete pretty quickly anyways
+                    while let Some(task) = self.pop_task() {
+                        // Then poll any tasks that might be pending
+                        let tasks = task.tasks_queued.into_inner();
+                        for task in tasks {
+                            if self.runtime.suspended_tasks.borrow().contains(&task) {
+                                let _ = self.runtime.handle_task_wakeup(task);
+                                // Running that task, may mark a scope higher up as dirty. If it does, return from the function early
+                                self.queue_events();
+                                if self.has_dirty_scopes() {
+                                    break 'wait_for_work;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                self.wait_for_event().await;
+            }
 
             // Render whatever work needs to be rendered, unlocking new futures and suspense leaves
-            self.render_immediate(&mut NoOpMutations);
+            let _runtime = RuntimeGuard::new(self.runtime.clone());
+            while let Some(work) = self.pop_work() {
+                // Then, poll any tasks that might be pending in the scope
+                for task in work.tasks {
+                    // During suspense, we only want to run tasks that are suspended
+                    if self.runtime.suspended_tasks.borrow().contains(&task) {
+                        let _ = self.runtime.handle_task_wakeup(task);
+                    }
+                }
+                // If the scope is dirty, run the scope and get the mutations
+                if work.rerun_scope {
+                    let new_nodes = self.run_scope(work.scope.id);
+
+                    self.diff_scope(&mut NoOpMutations, work.scope.id, new_nodes);
+                }
+            }
         }
     }
 

+ 28 - 3
packages/core/tests/suspense.rs

@@ -1,9 +1,10 @@
 use dioxus::prelude::*;
+use std::future::poll_fn;
+use std::task::Poll;
 
 #[test]
 fn suspense_resolves() {
     // wait just a moment, not enough time for the boundary to resolve
-
     tokio::runtime::Builder::new_current_thread()
         .build()
         .unwrap()
@@ -31,11 +32,35 @@ fn app() -> Element {
 fn suspended_child() -> Element {
     let mut val = use_signal(|| 0);
 
+    // Tasks that are not suspended should never be polled
+    spawn(async move {
+        panic!("Non-suspended task was polled");
+    });
+
+    // Memos should still work like normal
+    let memo = use_memo(move || val * 2);
+    assert_eq!(memo, val * 2);
+
     if val() < 3 {
-        spawn(async move {
+        let task = spawn(async move {
+            // Poll each task 3 times
+            let mut count = 0;
+            poll_fn(|cx| {
+                println!("polling... {}", count);
+                if count < 3 {
+                    count += 1;
+                    cx.waker().wake_by_ref();
+                    Poll::Pending
+                } else {
+                    Poll::Ready(())
+                }
+            })
+            .await;
+
+            println!("waiting... {}", val);
             val += 1;
         });
-        suspend()?;
+        suspend(task)?;
     }
 
     rsx!("child")

+ 2 - 2
packages/core/tests/task.rs

@@ -53,7 +53,7 @@ async fn running_async() {
 #[tokio::test]
 async fn yield_now_works() {
     thread_local! {
-        static SEQUENCE: std::cell::RefCell<Vec<usize>> = std::cell::RefCell::new(Vec::new());
+        static SEQUENCE: std::cell::RefCell<Vec<usize>> = const { std::cell::RefCell::new(Vec::new()) };
     }
 
     fn app() -> Element {
@@ -88,7 +88,7 @@ async fn yield_now_works() {
 #[tokio::test]
 async fn flushing() {
     thread_local! {
-        static SEQUENCE: std::cell::RefCell<Vec<usize>> = std::cell::RefCell::new(Vec::new());
+        static SEQUENCE: std::cell::RefCell<Vec<usize>> = const { std::cell::RefCell::new(Vec::new()) };
         static BROADCAST: (tokio::sync::broadcast::Sender<()>, tokio::sync::broadcast::Receiver<()>) = tokio::sync::broadcast::channel(1);
     }
 

+ 1 - 1
packages/fullstack/src/hooks/server_future.rs

@@ -54,7 +54,7 @@ where
     // Suspend if the value isn't ready
     match resource.state().cloned() {
         UseResourceState::Pending => {
-            suspend();
+            suspend(resource.task());
             None
         }
         _ => Some(resource),

+ 3 - 2
packages/hooks/src/use_effect.rs

@@ -1,5 +1,6 @@
 use dioxus_core::prelude::*;
 use dioxus_signals::ReactiveContext;
+use futures_util::StreamExt;
 
 /// `use_effect` will subscribe to any changes in the signal values it captures
 /// effects will always run after first mount and then whenever the signal values change
@@ -26,13 +27,13 @@ pub fn use_effect(mut callback: impl FnMut() + 'static) {
 
     use_hook(|| {
         spawn(async move {
-            let rc = ReactiveContext::new_with_origin(location);
+            let (rc, mut changed) = ReactiveContext::new_with_origin(location);
             loop {
                 // Run the effect
                 rc.run_in(&mut callback);
 
                 // Wait for context to change
-                rc.changed().await;
+                let _ = changed.next().await;
 
                 // Wait for the dom the be finished with sync work
                 wait_for_next_render().await;

+ 9 - 48
packages/hooks/src/use_memo.rs

@@ -1,8 +1,10 @@
 use crate::dependency::Dependency;
-use crate::use_signal;
+use crate::{use_callback, use_signal};
 use dioxus_core::prelude::*;
+use dioxus_signals::Memo;
 use dioxus_signals::{ReactiveContext, ReadOnlySignal, Readable, Signal, SignalData};
 use dioxus_signals::{Storage, Writable};
+use futures_util::StreamExt;
 
 /// Creates a new unsync Selector. The selector will be run immediately and whenever any signal it reads changes.
 ///
@@ -22,51 +24,9 @@ use dioxus_signals::{Storage, Writable};
 /// }
 /// ```
 #[track_caller]
-pub fn use_memo<R: PartialEq>(f: impl FnMut() -> R + 'static) -> ReadOnlySignal<R> {
-    use_maybe_sync_memo(f)
-}
-
-/// Creates a new Selector that may be sync. The selector will be run immediately and whenever any signal it reads changes.
-///
-/// Selectors can be used to efficiently compute derived data from signals.
-///
-/// ```rust
-/// use dioxus::prelude::*;
-/// use dioxus_signals::*;
-///
-/// fn App() -> Element {
-///     let mut count = use_signal(|| 0);
-///     let double = use_memo(move || count * 2);
-///     count += 1;
-///     assert_eq!(double(), count * 2);
-///
-///     rsx! { "{double}" }
-/// }
-/// ```
-#[track_caller]
-pub fn use_maybe_sync_memo<R: PartialEq, S: Storage<SignalData<R>>>(
-    mut f: impl FnMut() -> R + 'static,
-) -> ReadOnlySignal<R, S> {
-    use_hook(|| {
-        // Create a new reactive context for the memo
-        let rc = ReactiveContext::new();
-
-        // Create a new signal in that context, wiring up its dependencies and subscribers
-        let mut state: Signal<R, S> = rc.run_in(|| Signal::new_maybe_sync(f()));
-
-        spawn(async move {
-            loop {
-                rc.changed().await;
-                let new = rc.run_in(&mut f);
-                if new != *state.peek() {
-                    *state.write() = new;
-                }
-            }
-        });
-
-        // And just return the readonly variant of that signal
-        ReadOnlySignal::new_maybe_sync(state)
-    })
+pub fn use_memo<R: PartialEq>(f: impl FnMut() -> R + 'static) -> Memo<R> {
+    let mut callback = use_callback(f);
+    use_hook(|| Signal::memo(move || callback.call()))
 }
 
 /// Creates a new unsync Selector with some local dependencies. The selector will be run immediately and whenever any signal it reads or any dependencies it tracks changes
@@ -127,7 +87,7 @@ where
 
     let selector = use_hook(|| {
         // Get the current reactive context
-        let rc = ReactiveContext::new();
+        let (rc, mut changed) = ReactiveContext::new();
 
         // Create a new signal in that context, wiring up its dependencies and subscribers
         let mut state: Signal<R, S> =
@@ -135,7 +95,8 @@ where
 
         spawn(async move {
             loop {
-                rc.changed().await;
+                // Wait for context to change
+                let _ = changed.next().await;
 
                 let new = rc.run_in(|| f(dependencies_signal.read().clone()));
                 if new != *state.peek() {

+ 8 - 4
packages/hooks/src/use_resource.rs

@@ -6,8 +6,8 @@ use dioxus_core::{
     Task,
 };
 use dioxus_signals::*;
-use futures_util::{future, pin_mut, FutureExt};
-use std::future::Future;
+use futures_util::{future, pin_mut, FutureExt, StreamExt};
+use std::{cell::Cell, future::Future, rc::Rc};
 
 /// A memo that resolve to a value asynchronously.
 /// Unlike `use_future`, `use_resource` runs on the **server**
@@ -44,7 +44,10 @@ where
 {
     let mut value = use_signal(|| None);
     let mut state = use_signal(|| UseResourceState::Pending);
-    let rc = use_hook(ReactiveContext::new);
+    let (rc, changed) = use_hook(|| {
+        let (rc, changed) = ReactiveContext::new();
+        (rc, Rc::new(Cell::new(Some(changed))))
+    });
 
     let mut cb = use_callback(move || {
         // Create the user's task
@@ -70,10 +73,11 @@ where
     let mut task = use_hook(|| Signal::new(cb.call()));
 
     use_hook(|| {
+        let mut changed = changed.take().unwrap();
         spawn(async move {
             loop {
                 // Wait for the dependencies to change
-                rc.changed().await;
+                let _ = changed.next().await;
 
                 // Stop the old task
                 task.write().cancel();

+ 0 - 1
packages/signals/Cargo.toml

@@ -22,7 +22,6 @@ once_cell = "1.18.0"
 rustc-hash = { workspace = true }
 futures-channel = { workspace = true }
 futures-util = { workspace = true }
-flume = { version = "0.11.0", default-features = false, features = ["async"] }
 
 [dev-dependencies]
 dioxus = { workspace = true }

+ 4 - 1
packages/signals/src/copy_value.rs

@@ -237,14 +237,17 @@ impl<T: 'static, S: Storage<T>> Writable for CopyValue<T, S> {
         S::try_map_mut(mut_, f)
     }
 
-    fn try_write(&self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
+    #[track_caller]
+    fn try_write(&mut self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
         self.value.try_write()
     }
 
+    #[track_caller]
     fn write(&mut self) -> Self::Mut<T> {
         self.value.write()
     }
 
+    #[track_caller]
     fn set(&mut self, value: T) {
         self.value.set(value);
     }

+ 9 - 9
packages/signals/src/global/memo.rs

@@ -1,9 +1,9 @@
-use crate::{read::Readable, ReadableRef};
+use crate::{read::Readable, Memo, ReadableRef};
 use dioxus_core::prelude::{IntoAttributeValue, ScopeId};
 use generational_box::UnsyncStorage;
 use std::{mem::MaybeUninit, ops::Deref};
 
-use crate::{ReadOnlySignal, Signal};
+use crate::Signal;
 
 use super::get_global_context;
 
@@ -22,14 +22,14 @@ impl<T: PartialEq + 'static> GlobalMemo<T> {
     }
 
     /// Get the signal that backs this global.
-    pub fn signal(&self) -> ReadOnlySignal<T> {
+    pub fn memo(&self) -> Memo<T> {
         let key = self as *const _ as *const ();
 
         let context = get_global_context();
 
         let read = context.signal.borrow();
         match read.get(&key) {
-            Some(signal) => *signal.downcast_ref::<ReadOnlySignal<T>>().unwrap(),
+            Some(signal) => *signal.downcast_ref::<Memo<T>>().unwrap(),
             None => {
                 drop(read);
                 // Constructors are always run in the root scope
@@ -47,7 +47,7 @@ impl<T: PartialEq + 'static> GlobalMemo<T> {
 
     /// Get the generational id of the signal.
     pub fn id(&self) -> generational_box::GenerationalBoxId {
-        self.signal().id()
+        self.memo().id()
     }
 }
 
@@ -57,12 +57,12 @@ impl<T: PartialEq + 'static> Readable for GlobalMemo<T> {
 
     #[track_caller]
     fn try_read(&self) -> Result<ReadableRef<Self>, generational_box::BorrowError> {
-        self.signal().try_read()
+        self.memo().try_read()
     }
 
     #[track_caller]
     fn peek(&self) -> ReadableRef<Self> {
-        self.signal().peek()
+        self.memo().peek()
     }
 }
 
@@ -71,7 +71,7 @@ where
     T: Clone + IntoAttributeValue,
 {
     fn into_value(self) -> dioxus_core::AttributeValue {
-        self.signal().into_value()
+        self.memo().into_value()
     }
 }
 
@@ -81,7 +81,7 @@ impl<T: PartialEq + 'static> PartialEq for GlobalMemo<T> {
     }
 }
 
-/// Allow calling a signal with signal() syntax
+/// Allow calling a signal with memo() syntax
 ///
 /// Currently only limited to copy types, though could probably specialize for string/arc/rc
 impl<T: PartialEq + Clone + 'static> Deref for GlobalMemo<T> {

+ 1 - 1
packages/signals/src/global/signal.rs

@@ -103,7 +103,7 @@ impl<T: 'static> Writable for GlobalSignal<T> {
     }
 
     #[track_caller]
-    fn try_write(&self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
+    fn try_write(&mut self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
         self.signal().try_write()
     }
 }

+ 11 - 0
packages/signals/src/impls.rs

@@ -1,4 +1,5 @@
 use crate::copy_value::CopyValue;
+use crate::memo::Memo;
 use crate::read::Readable;
 use crate::signal::Signal;
 use crate::write::Writable;
@@ -159,6 +160,16 @@ impl<T: 'static, S: Storage<SignalData<T>>> Clone for ReadOnlySignal<T, S> {
 
 impl<T: 'static, S: Storage<SignalData<T>>> Copy for ReadOnlySignal<T, S> {}
 
+read_impls!(Memo: PartialEq);
+
+impl<T: 'static> Clone for Memo<T> {
+    fn clone(&self) -> Self {
+        *self
+    }
+}
+
+impl<T: 'static> Copy for Memo<T> {}
+
 read_impls!(GlobalSignal);
 default_impl!(GlobalSignal);
 

+ 3 - 0
packages/signals/src/lib.rs

@@ -19,6 +19,9 @@ pub use map::*;
 // mod comparer;
 // pub use comparer::*;
 
+mod memo;
+pub use memo::*;
+
 mod global;
 pub use global::*;
 

+ 220 - 0
packages/signals/src/memo.rs

@@ -0,0 +1,220 @@
+use crate::write::Writable;
+use crate::{read::Readable, ReactiveContext, ReadableRef, Signal};
+use crate::{CopyValue, ReadOnlySignal};
+use std::rc::Rc;
+use std::{
+    cell::RefCell,
+    ops::Deref,
+    panic::Location,
+    sync::{atomic::AtomicBool, Arc},
+};
+
+use dioxus_core::prelude::*;
+use futures_util::StreamExt;
+use generational_box::UnsyncStorage;
+use once_cell::sync::OnceCell;
+
+/// A thread local that can only be read from the thread it was created on.
+pub struct ThreadLocal<T> {
+    value: T,
+    owner: std::thread::ThreadId,
+}
+
+impl<T> ThreadLocal<T> {
+    /// Create a new thread local.
+    pub fn new(value: T) -> Self {
+        ThreadLocal {
+            value,
+            owner: std::thread::current().id(),
+        }
+    }
+
+    /// Get the value of the thread local.
+    pub fn get(&self) -> Option<&T> {
+        (self.owner == std::thread::current().id()).then_some(&self.value)
+    }
+}
+
+// SAFETY: This is safe because the thread local can only be read from the thread it was created on.
+unsafe impl<T> Send for ThreadLocal<T> {}
+unsafe impl<T> Sync for ThreadLocal<T> {}
+
+struct UpdateInformation<T> {
+    dirty: Arc<AtomicBool>,
+    callback: RefCell<Box<dyn FnMut() -> T>>,
+}
+
+/// A value that is memoized. This is useful for caching the result of a computation.
+pub struct Memo<T: 'static> {
+    inner: Signal<T>,
+    update: CopyValue<UpdateInformation<T>>,
+}
+
+impl<T> From<Memo<T>> for ReadOnlySignal<T>
+where
+    T: PartialEq,
+{
+    fn from(val: Memo<T>) -> Self {
+        ReadOnlySignal::new(val.inner)
+    }
+}
+
+impl<T: 'static> Memo<T> {
+    /// Create a new memo
+    #[track_caller]
+    pub fn new(mut f: impl FnMut() -> T + 'static) -> Self
+    where
+        T: PartialEq,
+    {
+        let dirty = Arc::new(AtomicBool::new(true));
+        let (tx, mut rx) = futures_channel::mpsc::unbounded();
+
+        let myself: Rc<OnceCell<Memo<T>>> = Rc::new(OnceCell::new());
+        let thread_local = ThreadLocal::new(myself.clone());
+
+        let callback = {
+            let dirty = dirty.clone();
+            move || match thread_local.get() {
+                Some(memo) => match memo.get() {
+                    Some(memo) => {
+                        memo.recompute();
+                    }
+                    None => {
+                        tracing::error!("Memo was not initialized in the same thread it was created in. This is likely a bug in dioxus");
+                        dirty.store(true, std::sync::atomic::Ordering::Relaxed);
+                        let _ = tx.unbounded_send(());
+                    }
+                },
+                None => {
+                    dirty.store(true, std::sync::atomic::Ordering::Relaxed);
+                    let _ = tx.unbounded_send(());
+                }
+            }
+        };
+        let rc = ReactiveContext::new_with_callback(
+            callback,
+            current_scope_id().unwrap(),
+            Location::caller(),
+        );
+
+        // Create a new signal in that context, wiring up its dependencies and subscribers
+        let mut recompute = move || rc.run_in(&mut f);
+        let value = recompute();
+        let recompute = RefCell::new(Box::new(recompute) as Box<dyn FnMut() -> T>);
+        let update = CopyValue::new(UpdateInformation {
+            dirty,
+            callback: recompute,
+        });
+        let state: Signal<T> = Signal::new(value);
+
+        let memo = Memo {
+            inner: state,
+            update,
+        };
+        let _ = myself.set(memo);
+
+        spawn(async move {
+            while rx.next().await.is_some() {
+                // Remove any pending updates
+                while rx.try_next().is_ok() {}
+                memo.recompute();
+            }
+        });
+
+        memo
+    }
+
+    /// Rerun the computation and update the value of the memo if the result has changed.
+    #[tracing::instrument(skip(self))]
+    fn recompute(&self)
+    where
+        T: PartialEq,
+    {
+        let mut update_copy = self.update;
+        let update_write = update_copy.write();
+        let peak = self.inner.peek();
+        let new_value = (update_write.callback.borrow_mut())();
+        if new_value != *peak {
+            drop(peak);
+            let mut copy = self.inner;
+            copy.set(new_value);
+            update_write
+                .dirty
+                .store(false, std::sync::atomic::Ordering::Relaxed);
+        }
+    }
+
+    /// Get the scope that the signal was created in.
+    pub fn origin_scope(&self) -> ScopeId {
+        self.inner.origin_scope()
+    }
+
+    /// Get the id of the signal.
+    pub fn id(&self) -> generational_box::GenerationalBoxId {
+        self.inner.id()
+    }
+}
+
+impl<T> Readable for Memo<T>
+where
+    T: PartialEq,
+{
+    type Target = T;
+    type Storage = UnsyncStorage;
+
+    #[track_caller]
+    fn try_read(&self) -> Result<ReadableRef<Self>, generational_box::BorrowError> {
+        let read = self.inner.try_read();
+        match read {
+            Ok(r) => {
+                let needs_update = self
+                    .update
+                    .read()
+                    .dirty
+                    .swap(false, std::sync::atomic::Ordering::Relaxed);
+                if needs_update {
+                    drop(r);
+                    self.recompute();
+                    self.inner.try_read()
+                } else {
+                    Ok(r)
+                }
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    /// Get the current value of the signal. **Unlike read, this will not subscribe the current scope to the signal which can cause parts of your UI to not update.**
+    ///
+    /// If the signal has been dropped, this will panic.
+    #[track_caller]
+    fn peek(&self) -> ReadableRef<Self> {
+        self.inner.peek()
+    }
+}
+
+impl<T> IntoAttributeValue for Memo<T>
+where
+    T: Clone + IntoAttributeValue + PartialEq,
+{
+    fn into_value(self) -> dioxus_core::AttributeValue {
+        self.with(|f| f.clone().into_value())
+    }
+}
+
+impl<T: 'static> PartialEq for Memo<T> {
+    fn eq(&self, other: &Self) -> bool {
+        self.inner == other.inner
+    }
+}
+
+impl<T: Clone> Deref for Memo<T>
+where
+    T: PartialEq,
+{
+    type Target = dyn Fn() -> T;
+
+    fn deref(&self) -> &Self::Target {
+        Readable::deref_impl(self)
+    }
+}

+ 41 - 67
packages/signals/src/reactive_context.rs

@@ -1,9 +1,9 @@
 use dioxus_core::prelude::{
     current_scope_id, has_context, provide_context, schedule_update_any, ScopeId,
 };
+use futures_channel::mpsc::UnboundedReceiver;
 use generational_box::SyncStorage;
-use rustc_hash::FxHashSet;
-use std::{cell::RefCell, hash::Hash, sync::Arc};
+use std::{cell::RefCell, hash::Hash};
 
 use crate::{CopyValue, Readable, Writable};
 
@@ -24,67 +24,51 @@ thread_local! {
 
 impl std::fmt::Display for ReactiveContext {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let read = self.inner.read();
-        match read.scope_subscriber {
-            Some(scope) => write!(f, "ReactiveContext for scope {:?}", scope),
-            None => {
-                #[cfg(debug_assertions)]
+        #[cfg(debug_assertions)]
+        {
+            if let Ok(read) = self.inner.try_read() {
                 return write!(f, "ReactiveContext created at {}", read.origin);
-                #[cfg(not(debug_assertions))]
-                write!(f, "ReactiveContext")
             }
         }
-    }
-}
-
-impl Default for ReactiveContext {
-    #[track_caller]
-    fn default() -> Self {
-        Self::new_for_scope(None, std::panic::Location::caller())
+        write!(f, "ReactiveContext")
     }
 }
 
 impl ReactiveContext {
     /// Create a new reactive context
     #[track_caller]
-    pub fn new() -> Self {
-        Self::default()
+    pub fn new() -> (Self, UnboundedReceiver<()>) {
+        Self::new_with_origin(std::panic::Location::caller())
     }
 
     /// Create a new reactive context with a location for debugging purposes
     /// This is useful for reactive contexts created within closures
-    pub fn new_with_origin(origin: &'static std::panic::Location<'static>) -> Self {
-        Self::new_for_scope(None, origin)
+    pub fn new_with_origin(
+        origin: &'static std::panic::Location<'static>,
+    ) -> (Self, UnboundedReceiver<()>) {
+        let (tx, rx) = futures_channel::mpsc::unbounded();
+        let callback = move || {
+            let _ = tx.unbounded_send(());
+        };
+        let _self = Self::new_with_callback(callback, current_scope_id().unwrap(), origin);
+        (_self, rx)
     }
 
-    /// Create a new reactive context that may update a scope
-    #[allow(unused)]
-    pub(crate) fn new_for_scope(
-        scope: Option<ScopeId>,
+    /// Create a new reactive context that may update a scope. When any signal that this context subscribes to changes, the callback will be run
+    pub fn new_with_callback(
+        callback: impl FnMut() + Send + Sync + 'static,
+        scope: ScopeId,
         origin: &'static std::panic::Location<'static>,
     ) -> Self {
-        let (tx, rx) = flume::unbounded();
-
-        let mut scope_subscribers = FxHashSet::default();
-        if let Some(scope) = scope {
-            scope_subscribers.insert(scope);
-        }
-
         let inner = Inner {
-            scope_subscriber: scope,
-            sender: tx,
             self_: None,
-            update_any: schedule_update_any(),
-            receiver: rx,
+            update: Box::new(callback),
             #[cfg(debug_assertions)]
             origin,
         };
 
         let mut self_ = Self {
-            inner: CopyValue::new_maybe_sync_in_scope(
-                inner,
-                scope.or_else(current_scope_id).unwrap(),
-            ),
+            inner: CopyValue::new_maybe_sync_in_scope(inner, scope),
         };
 
         self_.inner.write().self_ = Some(self_);
@@ -112,10 +96,17 @@ impl ReactiveContext {
         if let Some(cx) = has_context() {
             return Some(cx);
         }
+        let update_any = schedule_update_any();
+        let scope_id = current_scope_id().unwrap();
+        let update_scope = move || {
+            tracing::trace!("Marking scope {:?} as dirty", scope_id);
+            update_any(scope_id)
+        };
 
         // Otherwise, create a new context at the current scope
-        Some(provide_context(ReactiveContext::new_for_scope(
-            current_scope_id(),
+        Some(provide_context(ReactiveContext::new_with_callback(
+            update_scope,
+            scope_id,
             std::panic::Location::caller(),
         )))
     }
@@ -137,25 +128,18 @@ impl ReactiveContext {
     ///
     /// Returns true if the context was marked as dirty, or false if the context has been dropped
     pub fn mark_dirty(&self) -> bool {
-        if let Ok(self_read) = self.inner.try_read() {
+        let mut copy = self.inner;
+        if let Ok(mut self_write) = copy.try_write() {
             #[cfg(debug_assertions)]
             {
-                if let Some(scope) = self_read.scope_subscriber {
-                    tracing::trace!("Marking reactive context for scope {:?} as dirty", scope);
-                } else {
-                    tracing::trace!(
-                        "Marking reactive context created at {} as dirty",
-                        self_read.origin
-                    );
-                }
-            }
-            if let Some(scope) = self_read.scope_subscriber {
-                (self_read.update_any)(scope);
+                tracing::trace!(
+                    "Marking reactive context created at {} as dirty",
+                    self_write.origin
+                );
             }
 
-            // mark the listeners as dirty
-            // If the channel is full it means that the receivers have already been marked as dirty
-            _ = self_read.sender.try_send(());
+            (self_write.update)();
+
             true
         } else {
             false
@@ -166,12 +150,6 @@ impl ReactiveContext {
     pub fn origin_scope(&self) -> ScopeId {
         self.inner.origin_scope()
     }
-
-    /// Wait for this reactive context to change
-    pub async fn changed(&self) {
-        let rx = self.inner.read().receiver.clone();
-        _ = rx.recv_async().await;
-    }
 }
 
 impl Hash for ReactiveContext {
@@ -181,14 +159,10 @@ impl Hash for ReactiveContext {
 }
 
 struct Inner {
-    // A scope we mark as dirty when this context is written to
-    scope_subscriber: Option<ScopeId>,
     self_: Option<ReactiveContext>,
-    update_any: Arc<dyn Fn(ScopeId) + Send + Sync>,
 
     // Futures will call .changed().await
-    sender: flume::Sender<()>,
-    receiver: flume::Receiver<()>,
+    update: Box<dyn FnMut() + Send + Sync>,
 
     // Debug information for signal subscriptions
     #[cfg(debug_assertions)]

+ 10 - 37
packages/signals/src/signal.rs

@@ -1,11 +1,9 @@
+use crate::Memo;
 use crate::{
     read::Readable, write::Writable, CopyValue, GlobalMemo, GlobalSignal, ReactiveContext,
-    ReadOnlySignal, ReadableRef,
-};
-use dioxus_core::{
-    prelude::{spawn, IntoAttributeValue},
-    ScopeId,
+    ReadableRef,
 };
+use dioxus_core::{prelude::IntoAttributeValue, ScopeId};
 use generational_box::{AnyStorage, Storage, SyncStorage, UnsyncStorage};
 use std::{
     any::Any,
@@ -88,35 +86,8 @@ impl<T: PartialEq + 'static> Signal<T> {
     ///
     /// Selectors can be used to efficiently compute derived data from signals.
     #[track_caller]
-    pub fn memo(f: impl FnMut() -> T + 'static) -> ReadOnlySignal<T> {
-        Self::use_maybe_sync_memo(f)
-    }
-
-    /// Creates a new Selector that may be Sync + Send. The selector will be run immediately and whenever any signal it reads changes.
-    ///
-    /// Selectors can be used to efficiently compute derived data from signals.
-    #[track_caller]
-    pub fn use_maybe_sync_memo<S: Storage<SignalData<T>>>(
-        mut f: impl FnMut() -> T + 'static,
-    ) -> ReadOnlySignal<T, S> {
-        // Get the current reactive context
-        let rc = ReactiveContext::new();
-
-        // Create a new signal in that context, wiring up its dependencies and subscribers
-        let mut state: Signal<T, S> = rc.run_in(|| Signal::new_maybe_sync(f()));
-
-        spawn(async move {
-            loop {
-                rc.changed().await;
-                let new = f();
-                if new != *state.peek() {
-                    *state.write() = new;
-                }
-            }
-        });
-
-        // And just return the readonly variant of that signal
-        ReadOnlySignal::new_maybe_sync(state)
+    pub fn memo(f: impl FnMut() -> T + 'static) -> Memo<T> {
+        Memo::new(f)
     }
 }
 
@@ -179,8 +150,10 @@ impl<T: 'static, S: Storage<SignalData<T>>> Signal<T, S> {
         {
             let inner = self.inner.read();
 
-            let mut subscribers = inner.subscribers.lock().unwrap();
-            subscribers.retain(|reactive_context| reactive_context.mark_dirty())
+            // We cannot hold the subscribers lock while calling mark_dirty, because mark_dirty can run user code which may cause a new subscriber to be added. If we hold the lock, we will deadlock.
+            let mut subscribers = std::mem::take(&mut *inner.subscribers.lock().unwrap());
+            subscribers.retain(|reactive_context| reactive_context.mark_dirty());
+            *inner.subscribers.lock().unwrap() = subscribers;
         }
     }
 
@@ -237,7 +210,7 @@ impl<T: 'static, S: Storage<SignalData<T>>> Writable for Signal<T, S> {
     }
 
     #[track_caller]
-    fn try_write(&self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
+    fn try_write(&mut self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
         self.inner.try_write().map(|inner| {
             let borrow = S::map_mut(inner, |v| &mut v.value);
             Write {

+ 1 - 1
packages/signals/src/write.rs

@@ -27,7 +27,7 @@ pub trait Writable: Readable {
     }
 
     /// Try to get a mutable reference to the value. If the value has been dropped, this will panic.
-    fn try_write(&self) -> Result<Self::Mut<Self::Target>, generational_box::BorrowMutError>;
+    fn try_write(&mut self) -> Result<Self::Mut<Self::Target>, generational_box::BorrowMutError>;
 
     /// Run a function with a mutable reference to the value. If the value has been dropped, this will panic.
     #[track_caller]