Browse Source

rerun tasks in the same order as components

Evan Almloff 1 năm trước cách đây
mục cha
commit
c7ffdc7b29

+ 3 - 2
packages/core/src/arena.rs

@@ -1,4 +1,5 @@
-use crate::{innerlude::DirtyScope, virtual_dom::VirtualDom, ScopeId};
+use crate::innerlude::ScopeOrder;
+use crate::{virtual_dom::VirtualDom, ScopeId};
 
 /// An Element's unique identifier.
 ///
@@ -74,7 +75,7 @@ impl VirtualDom {
             context.height
         };
 
-        self.dirty_scopes.remove(&DirtyScope { height, id });
+        self.dirty_scopes.remove(&ScopeOrder::new(height, id));
     }
 }
 

+ 1 - 4
packages/core/src/diff/component.rs

@@ -91,10 +91,7 @@ impl VNode {
         dom.diff_scope(to, scope_id, new);
 
         let height = dom.runtime.get_state(scope_id).unwrap().height;
-        dom.dirty_scopes.remove(&DirtyScope {
-            height,
-            id: scope_id,
-        });
+        dom.dirty_scopes.remove(&DirtyScope::new(height, scope_id));
     }
 
     fn replace_vcomponent(

+ 77 - 7
packages/core/src/dirty_scope.rs

@@ -1,33 +1,103 @@
+use crate::ScopeId;
+use crate::Task;
+use std::borrow::Borrow;
+use std::cell::Cell;
+use std::cell::RefCell;
 use std::hash::Hash;
 
-use crate::ScopeId;
+#[derive(Debug, Clone, Eq)]
+pub struct ScopeOrder {
+    pub(crate) height: u32,
+    pub(crate) id: ScopeId,
+}
+
+impl ScopeOrder {
+    pub fn new(height: u32, id: ScopeId) -> Self {
+        Self { height, id }
+    }
+}
+
+impl PartialEq for ScopeOrder {
+    fn eq(&self, other: &Self) -> bool {
+        self.id == other.id
+    }
+}
+
+impl PartialOrd for ScopeOrder {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for ScopeOrder {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.height.cmp(&other.height).then(self.id.cmp(&other.id))
+    }
+}
+
+impl Hash for ScopeOrder {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.id.hash(state);
+    }
+}
 
 #[derive(Debug, Clone, Eq)]
 pub struct DirtyScope {
-    pub height: u32,
-    pub id: ScopeId,
+    pub order: ScopeOrder,
+    pub rerun_queued: Cell<bool>,
+    pub tasks_queued: RefCell<Vec<Task>>,
+}
+
+impl From<ScopeOrder> for DirtyScope {
+    fn from(order: ScopeOrder) -> Self {
+        Self {
+            order,
+            rerun_queued: false.into(),
+            tasks_queued: Vec::new().into(),
+        }
+    }
+}
+
+impl DirtyScope {
+    pub fn new(height: u32, id: ScopeId) -> Self {
+        ScopeOrder { height, id }.into()
+    }
+
+    pub fn queue_task(&self, task: Task) {
+        self.tasks_queued.borrow_mut().push(task);
+    }
+
+    pub fn queue_rerun(&self) {
+        self.rerun_queued.set(true);
+    }
+}
+
+impl Borrow<ScopeOrder> for DirtyScope {
+    fn borrow(&self) -> &ScopeOrder {
+        &self.order
+    }
 }
 
 impl PartialOrd for DirtyScope {
     fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
+        Some(self.order.cmp(&other.order))
     }
 }
 
 impl Ord for DirtyScope {
     fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        self.height.cmp(&other.height).then(self.id.cmp(&other.id))
+        self.order.cmp(&other.order)
     }
 }
 
 impl PartialEq for DirtyScope {
     fn eq(&self, other: &Self) -> bool {
-        self.id == other.id
+        self.order == other.order
     }
 }
 
 impl Hash for DirtyScope {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.id.hash(state);
+        self.order.hash(state);
     }
 }

+ 4 - 5
packages/core/src/scope_arena.rs

@@ -1,6 +1,7 @@
+use crate::innerlude::ScopeOrder;
 use crate::{
     any_props::{AnyProps, BoxedAnyProps},
-    innerlude::{DirtyScope, ScopeState},
+    innerlude::ScopeState,
     nodes::RenderReturn,
     scope_context::Scope,
     scopes::ScopeId,
@@ -66,10 +67,8 @@ impl VirtualDom {
         context.render_count.set(context.render_count.get() + 1);
 
         // remove this scope from dirty scopes
-        self.dirty_scopes.remove(&DirtyScope {
-            height: context.height,
-            id: context.id,
-        });
+        self.dirty_scopes
+            .remove(&ScopeOrder::new(context.height, scope_id));
 
         if context.suspended.get() {
             if matches!(new_nodes, RenderReturn::Aborted(_)) {

+ 4 - 0
packages/core/src/tasks.rs

@@ -135,6 +135,10 @@ impl Runtime {
         self.tasks.borrow().get(task.0)?.parent
     }
 
+    pub(crate) fn task_scope(&self, task: Task) -> Option<ScopeId> {
+        self.tasks.borrow().get(task.0).map(|t| t.scope)
+    }
+
     pub(crate) fn handle_task_wakeup(&self, id: Task) -> Poll<()> {
         debug_assert!(Runtime::current().is_some(), "Must be in a dioxus runtime");
 

+ 78 - 15
packages/core/src/virtual_dom.rs

@@ -2,6 +2,8 @@
 //!
 //! This module provides the primary mechanics to create a hook-based, concurrent VDOM for Rust.
 
+use crate::innerlude::ScopeOrder;
+use crate::Task;
 use crate::{
     any_props::AnyProps,
     arena::ElementId,
@@ -183,6 +185,7 @@ pub struct VirtualDom {
     pub(crate) scopes: Slab<ScopeState>,
 
     pub(crate) dirty_scopes: BTreeSet<DirtyScope>,
+    pub(crate) scopes_need_rerun: bool,
 
     // Maps a template path to a map of byte indexes to templates
     pub(crate) templates: FxHashMap<TemplateId, FxHashMap<usize, Template>>,
@@ -310,6 +313,7 @@ impl VirtualDom {
             rx,
             runtime: Runtime::new(tx),
             scopes: Default::default(),
+            scopes_need_rerun: false,
             dirty_scopes: Default::default(),
             templates: Default::default(),
             queued_templates: Default::default(),
@@ -374,10 +378,39 @@ impl VirtualDom {
         };
 
         tracing::trace!("Marking scope {:?} as dirty", id);
-        self.dirty_scopes.insert(DirtyScope {
-            height: scope.height(),
-            id,
-        });
+        self.scopes_need_rerun = true;
+        let order = ScopeOrder::new(scope.height(), id);
+        match self.dirty_scopes.get(&order) {
+            Some(dirty) => {
+                dirty.queue_rerun();
+            }
+            None => {
+                let dirty: DirtyScope = order.into();
+                dirty.queue_rerun();
+                self.dirty_scopes.insert(dirty);
+            }
+        }
+    }
+
+    /// Mark a task as dirty
+    fn mark_task_dirty(&mut self, task: Task) {
+        let Some(scope) = self.runtime.task_scope(task) else {
+            return;
+        };
+        let Some(scope) = self.runtime.get_state(scope) else {
+            return;
+        };
+        let order = ScopeOrder::new(scope.height(), scope.id);
+        match self.dirty_scopes.get(&order) {
+            Some(dirty) => {
+                dirty.queue_task(task);
+            }
+            None => {
+                let dirty: DirtyScope = order.into();
+                dirty.queue_task(task);
+                self.dirty_scopes.insert(dirty);
+            }
+        }
     }
 
     /// Call a listener inside the VirtualDom with data from outside the VirtualDom. **The ElementId passed in must be the id of an element with a listener, not a static node or a text node.**
@@ -439,13 +472,32 @@ impl VirtualDom {
             self.process_events();
 
             // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
-            if !self.dirty_scopes.is_empty() {
+            if self.scopes_need_rerun {
                 return;
             }
 
             // Make sure we set the runtime since we're running user code
             let _runtime = RuntimeGuard::new(self.runtime.clone());
 
+            // Next, run any queued tasks
+            // We choose not to poll the deadline since we complete pretty quickly anyways
+            while let Some(dirty) = self.dirty_scopes.pop_first() {
+                // If the scope doesn't exist for whatever reason, then we should skip it
+                if !self.scopes.contains(dirty.order.id.0) {
+                    continue;
+                }
+
+                // Then poll any tasks that might be pending
+                for task in dirty.tasks_queued.borrow().iter() {
+                    let _ = self.runtime.handle_task_wakeup(*task);
+                    // Running that task, may mark a scope higher up as dirty. If it does, return from the function early
+                    self.process_events();
+                    if self.scopes_need_rerun {
+                        return;
+                    }
+                }
+            }
+
             // Hold a lock to the flush sync to prevent tasks from running in the event we get an immediate
             // When we're doing awaiting the rx, the lock will be dropped and tasks waiting on the lock will get waked
             // We have to own the lock since poll_tasks is cancel safe - the future that this is running in might get dropped
@@ -455,7 +507,12 @@ impl VirtualDom {
 
             match self.rx.next().await.expect("channel should never close") {
                 SchedulerMsg::Immediate(id) => self.mark_dirty(id),
-                SchedulerMsg::TaskNotified(id) => _ = self.runtime.handle_task_wakeup(id),
+                SchedulerMsg::TaskNotified(id) => {
+                    // _ = self.runtime.handle_task_wakeup(id)
+                    // Instead of running the task immediately, we insert it into the runtime's task queue.
+                    // The task may be marked dirty at the same time as the scope that owns the task is dropped.
+                    self.mark_task_dirty(id);
+                }
             };
         }
     }
@@ -468,7 +525,7 @@ impl VirtualDom {
         while let Ok(Some(msg)) = self.rx.try_next() {
             match msg {
                 SchedulerMsg::Immediate(id) => self.mark_dirty(id),
-                SchedulerMsg::TaskNotified(task) => _ = self.runtime.handle_task_wakeup(task),
+                SchedulerMsg::TaskNotified(task) => self.mark_task_dirty(task),
             }
         }
     }
@@ -489,10 +546,8 @@ impl VirtualDom {
                 {
                     let context = scope.state();
                     let height = context.height;
-                    self.dirty_scopes.insert(DirtyScope {
-                        height,
-                        id: context.id,
-                    });
+                    self.dirty_scopes
+                        .insert(DirtyScope::new(height, context.id));
                 }
             }
         }
@@ -558,18 +613,26 @@ impl VirtualDom {
         // We choose not to poll the deadline since we complete pretty quickly anyways
         while let Some(dirty) = self.dirty_scopes.pop_first() {
             // If the scope doesn't exist for whatever reason, then we should skip it
-            if !self.scopes.contains(dirty.id.0) {
+            if !self.scopes.contains(dirty.order.id.0) {
                 continue;
             }
 
             {
                 let _runtime = RuntimeGuard::new(self.runtime.clone());
-                // Run the scope and get the mutations
-                let new_nodes = self.run_scope(dirty.id);
+                // Poll any tasks that might be pending in the scope
+                for task in dirty.tasks_queued.borrow().iter() {
+                    let _ = self.runtime.handle_task_wakeup(*task);
+                }
+                // If the scope is dirty, run the scope and get the mutations
+                if dirty.rerun_queued.get() {
+                    let new_nodes = self.run_scope(dirty.order.id);
 
-                self.diff_scope(to, dirty.id, new_nodes);
+                    self.diff_scope(to, dirty.order.id, new_nodes);
+                }
             }
         }
+
+        self.scopes_need_rerun = false;
     }
 
     /// [`Self::render_immediate`] to a vector of mutations for testing purposes

+ 34 - 7
packages/core/tests/task.rs

@@ -89,14 +89,19 @@ async fn yield_now_works() {
 async fn flushing() {
     thread_local! {
         static SEQUENCE: std::cell::RefCell<Vec<usize>> = std::cell::RefCell::new(Vec::new());
+        static BROADCAST: (tokio::sync::broadcast::Sender<()>, tokio::sync::broadcast::Receiver<()>) = tokio::sync::broadcast::channel(1);
     }
 
     fn app() -> Element {
+        if generation() > 0 {
+            SEQUENCE.with(|s| s.borrow_mut().push(0));
+        }
         use_hook(|| {
             spawn(async move {
                 for _ in 0..10 {
                     flush_sync().await;
                     SEQUENCE.with(|s| s.borrow_mut().push(1));
+                    BROADCAST.with(|b| b.1.resubscribe()).recv().await.unwrap();
                 }
             })
         });
@@ -106,11 +111,12 @@ async fn flushing() {
                 for _ in 0..10 {
                     flush_sync().await;
                     SEQUENCE.with(|s| s.borrow_mut().push(2));
+                    BROADCAST.with(|b| b.1.resubscribe()).recv().await.unwrap();
                 }
             })
         });
 
-        rsx!({})
+        rsx! {}
     }
 
     let mut dom = VirtualDom::new(app);
@@ -119,11 +125,11 @@ async fn flushing() {
 
     let fut = async {
         // Trigger the flush by waiting for work
-        for _ in 0..40 {
-            tokio::select! {
-                _ = dom.wait_for_work() => {}
-                _ = tokio::time::sleep(Duration::from_millis(1)) => {}
-            };
+        for _ in 0..10 {
+            dom.mark_dirty(ScopeId(0));
+            BROADCAST.with(|b| b.0.send(()).unwrap());
+            dom.wait_for_work().await;
+            dom.render_immediate(&mut dioxus_core::NoOpMutations);
         }
     };
 
@@ -132,5 +138,26 @@ async fn flushing() {
         _ = tokio::time::sleep(Duration::from_millis(500)) => {}
     };
 
-    SEQUENCE.with(|s| assert_eq!(s.borrow().len(), 20));
+    SEQUENCE.with(|s| {
+        let s = s.borrow();
+        println!("{:?}", s);
+        assert_eq!(s.len(), 30);
+        // We need to check if every three elements look like [0, 1, 2] or [0, 2, 1]
+        let mut has_seen_1 = false;
+        for (i, &x) in s.iter().enumerate() {
+            let stage = i % 3;
+            if stage == 0 {
+                assert_eq!(x, 0);
+            } else if stage == 1 {
+                assert!(x == 1 || x == 2);
+                has_seen_1 = x == 1;
+            } else if stage == 2 {
+                if has_seen_1 {
+                    assert_eq!(x, 2);
+                } else {
+                    assert_eq!(x, 1);
+                }
+            }
+        }
+    });
 }

+ 1 - 3
packages/hooks/src/use_memo.rs

@@ -48,7 +48,7 @@ pub fn use_maybe_sync_memo<R: PartialEq, S: Storage<SignalData<R>>>(
     mut f: impl FnMut() -> R + 'static,
 ) -> ReadOnlySignal<R, S> {
     use_hook(|| {
-        // Get the current reactive context
+        // Create a new reactive context for the memo
         let rc = ReactiveContext::new();
 
         // Create a new signal in that context, wiring up its dependencies and subscribers
@@ -56,8 +56,6 @@ pub fn use_maybe_sync_memo<R: PartialEq, S: Storage<SignalData<R>>>(
 
         spawn(async move {
             loop {
-                // Wait for the dom the be finished with sync work
-                flush_sync().await;
                 rc.changed().await;
                 let new = rc.run_in(&mut f);
                 if new != *state.peek() {

+ 146 - 148
packages/signals/tests/memo.rs

@@ -1,148 +1,146 @@
-// TODO: fix #1935
-
-// #![allow(unused, non_upper_case_globals, non_snake_case)]
-// use dioxus_core::NoOpMutations;
-// use std::collections::HashMap;
-// use std::rc::Rc;
-
-// use dioxus::html::p;
-// use dioxus::prelude::*;
-// use dioxus_core::ElementId;
-// use dioxus_signals::*;
-// use std::cell::RefCell;
-
-// #[test]
-// fn memos_rerun() {
-//     let _ = simple_logger::SimpleLogger::new().init();
-
-//     #[derive(Default)]
-//     struct RunCounter {
-//         component: usize,
-//         effect: usize,
-//     }
-
-//     let counter = Rc::new(RefCell::new(RunCounter::default()));
-//     let mut dom = VirtualDom::new_with_props(
-//         |counter: Rc<RefCell<RunCounter>>| {
-//             counter.borrow_mut().component += 1;
-
-//             let mut signal = use_signal(|| 0);
-//             let memo = use_memo({
-//                 to_owned![counter];
-//                 move || {
-//                     counter.borrow_mut().effect += 1;
-//                     println!("Signal: {:?}", signal);
-//                     signal()
-//                 }
-//             });
-//             assert_eq!(memo(), 0);
-//             signal += 1;
-//             assert_eq!(memo(), 1);
-
-//             rsx! {
-//                 div {}
-//             }
-//         },
-//         counter.clone(),
-//     );
-
-//     dom.rebuild_in_place();
-
-//     let current_counter = counter.borrow();
-//     assert_eq!(current_counter.component, 1);
-//     assert_eq!(current_counter.effect, 2);
-// }
-
-// #[test]
-// fn memos_prevents_component_rerun() {
-//     let _ = simple_logger::SimpleLogger::new().init();
-
-//     #[derive(Default)]
-//     struct RunCounter {
-//         component: usize,
-//         memo: usize,
-//     }
-
-//     let counter = Rc::new(RefCell::new(RunCounter::default()));
-//     let mut dom = VirtualDom::new_with_props(
-//         |props: Rc<RefCell<RunCounter>>| {
-//             let mut signal = use_signal(|| 0);
-
-//             if generation() == 1 {
-//                 *signal.write() = 0;
-//             }
-//             if generation() == 2 {
-//                 println!("Writing to signal");
-//                 *signal.write() = 1;
-//             }
-
-//             rsx! {
-//                 Child {
-//                     signal: signal,
-//                     counter: props.clone(),
-//                 }
-//             }
-//         },
-//         counter.clone(),
-//     );
-
-//     #[derive(Default, Props, Clone)]
-//     struct ChildProps {
-//         signal: Signal<usize>,
-//         counter: Rc<RefCell<RunCounter>>,
-//     }
-
-//     impl PartialEq for ChildProps {
-//         fn eq(&self, other: &Self) -> bool {
-//             self.signal == other.signal
-//         }
-//     }
-
-//     fn Child(props: ChildProps) -> Element {
-//         let counter = &props.counter;
-//         let signal = props.signal;
-//         counter.borrow_mut().component += 1;
-
-//         let memo = use_memo({
-//             to_owned![counter];
-//             move || {
-//                 counter.borrow_mut().memo += 1;
-//                 println!("Signal: {:?}", signal);
-//                 signal()
-//             }
-//         });
-//         match generation() {
-//             0 => {
-//                 assert_eq!(memo(), 0);
-//             }
-//             1 => {
-//                 assert_eq!(memo(), 1);
-//             }
-//             _ => panic!("Unexpected generation"),
-//         }
-
-//         rsx! {
-//             div {}
-//         }
-//     }
-
-//     dom.rebuild_in_place();
-//     dom.mark_dirty(ScopeId::ROOT);
-//     dom.render_immediate(&mut NoOpMutations);
-
-//     {
-//         let current_counter = counter.borrow();
-//         assert_eq!(current_counter.component, 1);
-//         assert_eq!(current_counter.memo, 2);
-//     }
-
-//     dom.mark_dirty(ScopeId::ROOT);
-//     dom.render_immediate(&mut NoOpMutations);
-//     dom.render_immediate(&mut NoOpMutations);
-
-//     {
-//         let current_counter = counter.borrow();
-//         assert_eq!(current_counter.component, 2);
-//         assert_eq!(current_counter.memo, 3);
-//     }
-// }
+#![allow(unused, non_upper_case_globals, non_snake_case)]
+use dioxus_core::NoOpMutations;
+use std::collections::HashMap;
+use std::rc::Rc;
+
+use dioxus::html::p;
+use dioxus::prelude::*;
+use dioxus_core::ElementId;
+use dioxus_signals::*;
+use std::cell::RefCell;
+
+#[test]
+fn memos_rerun() {
+    let _ = simple_logger::SimpleLogger::new().init();
+
+    #[derive(Default)]
+    struct RunCounter {
+        component: usize,
+        effect: usize,
+    }
+
+    let counter = Rc::new(RefCell::new(RunCounter::default()));
+    let mut dom = VirtualDom::new_with_props(
+        |counter: Rc<RefCell<RunCounter>>| {
+            counter.borrow_mut().component += 1;
+
+            let mut signal = use_signal(|| 0);
+            let memo = use_memo({
+                to_owned![counter];
+                move || {
+                    counter.borrow_mut().effect += 1;
+                    println!("Signal: {:?}", signal);
+                    signal()
+                }
+            });
+            assert_eq!(memo(), 0);
+            signal += 1;
+            assert_eq!(memo(), 1);
+
+            rsx! {
+                div {}
+            }
+        },
+        counter.clone(),
+    );
+
+    dom.rebuild_in_place();
+
+    let current_counter = counter.borrow();
+    assert_eq!(current_counter.component, 1);
+    assert_eq!(current_counter.effect, 2);
+}
+
+#[test]
+fn memos_prevents_component_rerun() {
+    let _ = simple_logger::SimpleLogger::new().init();
+
+    #[derive(Default)]
+    struct RunCounter {
+        component: usize,
+        memo: usize,
+    }
+
+    let counter = Rc::new(RefCell::new(RunCounter::default()));
+    let mut dom = VirtualDom::new_with_props(
+        |props: Rc<RefCell<RunCounter>>| {
+            let mut signal = use_signal(|| 0);
+
+            if generation() == 1 {
+                *signal.write() = 0;
+            }
+            if generation() == 2 {
+                println!("Writing to signal");
+                *signal.write() = 1;
+            }
+
+            rsx! {
+                Child {
+                    signal: signal,
+                    counter: props.clone(),
+                }
+            }
+        },
+        counter.clone(),
+    );
+
+    #[derive(Default, Props, Clone)]
+    struct ChildProps {
+        signal: Signal<usize>,
+        counter: Rc<RefCell<RunCounter>>,
+    }
+
+    impl PartialEq for ChildProps {
+        fn eq(&self, other: &Self) -> bool {
+            self.signal == other.signal
+        }
+    }
+
+    fn Child(props: ChildProps) -> Element {
+        let counter = &props.counter;
+        let signal = props.signal;
+        counter.borrow_mut().component += 1;
+
+        let memo = use_memo({
+            to_owned![counter];
+            move || {
+                counter.borrow_mut().memo += 1;
+                println!("Signal: {:?}", signal);
+                signal()
+            }
+        });
+        match generation() {
+            0 => {
+                assert_eq!(memo(), 0);
+            }
+            1 => {
+                assert_eq!(memo(), 1);
+            }
+            _ => panic!("Unexpected generation"),
+        }
+
+        rsx! {
+            div {}
+        }
+    }
+
+    dom.rebuild_in_place();
+    dom.mark_dirty(ScopeId::ROOT);
+    dom.render_immediate(&mut NoOpMutations);
+
+    {
+        let current_counter = counter.borrow();
+        assert_eq!(current_counter.component, 1);
+        assert_eq!(current_counter.memo, 2);
+    }
+
+    dom.mark_dirty(ScopeId::ROOT);
+    dom.render_immediate(&mut NoOpMutations);
+    dom.render_immediate(&mut NoOpMutations);
+
+    {
+        let current_counter = counter.borrow();
+        assert_eq!(current_counter.component, 2);
+        assert_eq!(current_counter.memo, 3);
+    }
+}