瀏覽代碼

only poll suspended futures

Evan Almloff 1 年之前
父節點
當前提交
c9603ea984

+ 3 - 3
packages/core/src/global_context.rs

@@ -50,9 +50,9 @@ pub fn provide_root_context<T: 'static + Clone>(value: T) -> T {
         .expect("to be in a dioxus runtime")
 }
 
-/// Suspends the current component
-pub fn suspend() -> Option<Element> {
-    Runtime::with_current_scope(|cx| cx.suspend());
+/// Suspended the current component on a specific task and then return None
+pub fn suspend(task: Task) -> Element {
+    Runtime::with_current_scope(|cx| cx.suspend(task));
     None
 }
 

+ 8 - 2
packages/core/src/runtime.rs

@@ -1,3 +1,5 @@
+use rustc_hash::FxHashSet;
+
 use crate::{
     innerlude::{LocalTask, SchedulerMsg},
     render_signal::RenderSignal,
@@ -24,11 +26,14 @@ pub struct Runtime {
     // We use this to track the current task
     pub(crate) current_task: Cell<Option<Task>>,
 
-    pub(crate) rendering: Cell<bool>,
-
     /// Tasks created with cx.spawn
     pub(crate) tasks: RefCell<slab::Slab<Rc<LocalTask>>>,
 
+    // Currently suspended tasks
+    pub(crate) suspended_tasks: RefCell<FxHashSet<Task>>,
+
+    pub(crate) rendering: Cell<bool>,
+
     pub(crate) sender: futures_channel::mpsc::UnboundedSender<SchedulerMsg>,
 
     // Synchronous tasks need to be run after the next render. The virtual dom stores a list of those tasks to send a signal to them when the next render is done.
@@ -45,6 +50,7 @@ impl Runtime {
             scope_stack: Default::default(),
             current_task: Default::default(),
             tasks: Default::default(),
+            suspended_tasks: Default::default(),
         })
     }
 

+ 3 - 5
packages/core/src/scope_arena.rs

@@ -41,7 +41,6 @@ impl VirtualDom {
         let new_nodes = {
             let context = scope.state();
 
-            context.suspended.set(false);
             context.hook_index.set(0);
 
             // Run all pre-render hooks
@@ -70,12 +69,11 @@ impl VirtualDom {
         self.dirty_scopes
             .remove(&ScopeOrder::new(context.height, scope_id));
 
-        if context.suspended.get() {
+        if let Some(task) = context.last_suspendable_task.take() {
             if matches!(new_nodes, RenderReturn::Aborted(_)) {
-                self.suspended_scopes.insert(context.id);
+                tracing::trace!("Suspending {:?} on {:?}", scope_id, task);
+                self.runtime.suspended_tasks.borrow_mut().insert(task);
             }
-        } else if !self.suspended_scopes.is_empty() {
-            _ = self.suspended_scopes.remove(&context.id);
         }
 
         self.runtime.scope_stack.borrow_mut().pop();

+ 9 - 8
packages/core/src/scope_context.rs

@@ -16,13 +16,14 @@ pub(crate) struct Scope {
     pub(crate) parent_id: Option<ScopeId>,
     pub(crate) height: u32,
     pub(crate) render_count: Cell<usize>,
-    pub(crate) suspended: Cell<bool>,
 
     // Note: the order of the hook and context fields is important. The hooks field must be dropped before the contexts field in case a hook drop implementation tries to access a context.
     pub(crate) hooks: RefCell<Vec<Box<dyn Any>>>,
     pub(crate) hook_index: Cell<usize>,
     pub(crate) shared_contexts: RefCell<Vec<Box<dyn Any>>>,
     pub(crate) spawned_tasks: RefCell<FxHashSet<Task>>,
+    /// The task that was last spawned that may suspend. We use this task to check what task to suspend in the event of an early None return from a component
+    pub(crate) last_suspendable_task: Cell<Option<Task>>,
     pub(crate) before_render: RefCell<Vec<Box<dyn FnMut()>>>,
     pub(crate) after_render: RefCell<Vec<Box<dyn FnMut()>>>,
 }
@@ -40,9 +41,9 @@ impl Scope {
             parent_id,
             height,
             render_count: Cell::new(0),
-            suspended: Cell::new(false),
             shared_contexts: RefCell::new(vec![]),
             spawned_tasks: RefCell::new(FxHashSet::default()),
+            last_suspendable_task: Cell::new(None),
             hooks: RefCell::new(vec![]),
             hook_index: Cell::new(0),
             before_render: RefCell::new(vec![]),
@@ -241,9 +242,9 @@ impl Scope {
         Runtime::with(|rt| rt.spawn(self.id, fut)).expect("Runtime to exist")
     }
 
-    /// Mark this component as suspended and then return None
-    pub fn suspend(&self) -> Option<Element> {
-        self.suspended.set(true);
+    /// Mark this component as suspended on a specific task and then return None
+    pub fn suspend(&self, task: Task) -> Option<Element> {
+        self.last_suspendable_task.set(Some(task));
         None
     }
 
@@ -340,10 +341,10 @@ impl ScopeId {
             .expect("to be in a dioxus runtime")
     }
 
-    /// Suspends the current component
-    pub fn suspend(self) -> Option<Element> {
+    /// Suspended a component on a specific task and then return None
+    pub fn suspend(self, task: Task) -> Option<Element> {
         Runtime::with_scope(self, |cx| {
-            cx.suspend();
+            cx.suspend(task);
         });
         None
     }

+ 4 - 0
packages/core/src/tasks.rs

@@ -173,6 +173,9 @@ impl Runtime {
 
             // Remove it from the scheduler
             self.tasks.borrow_mut().try_remove(id.0);
+
+            // Remove it from the suspended tasks
+            self.suspended_tasks.borrow_mut().remove(&id);
         }
 
         // Remove the scope from the stack
@@ -187,6 +190,7 @@ impl Runtime {
     ///
     /// This does not abort the task, so you'll want to wrap it in an abort handle if that's important to you
     pub(crate) fn remove_task(&self, id: Task) -> Option<Rc<LocalTask>> {
+        self.suspended_tasks.borrow_mut().remove(&id);
         self.tasks.borrow_mut().try_remove(id.0)
     }
 }

+ 85 - 25
packages/core/src/virtual_dom.rs

@@ -18,7 +18,7 @@ use crate::{
     AttributeValue, ComponentFunction, Element, Event, Mutations,
 };
 use futures_util::StreamExt;
-use rustc_hash::{FxHashMap, FxHashSet};
+use rustc_hash::FxHashMap;
 use slab::Slab;
 use std::{any::Any, rc::Rc};
 use tracing::instrument;
@@ -201,9 +201,6 @@ pub struct VirtualDom {
 
     pub(crate) runtime: Rc<Runtime>,
 
-    // Currently suspended scopes
-    pub(crate) suspended_scopes: FxHashSet<ScopeId>,
-
     rx: futures_channel::mpsc::UnboundedReceiver<SchedulerMsg>,
 }
 
@@ -319,7 +316,6 @@ impl VirtualDom {
             queued_templates: Default::default(),
             elements: Default::default(),
             mounts: Default::default(),
-            suspended_scopes: Default::default(),
         };
 
         let root = dom.new_scope(Box::new(root), "app");
@@ -448,13 +444,6 @@ impl VirtualDom {
     /// ```
     #[instrument(skip(self), level = "trace", name = "VirtualDom::wait_for_work")]
     pub async fn wait_for_work(&mut self) {
-        // And then poll the futures
-        self.poll_tasks().await;
-    }
-
-    /// Poll the scheduler for any work
-    #[instrument(skip(self), level = "trace", name = "VirtualDom::poll_tasks")]
-    async fn poll_tasks(&mut self) {
         loop {
             // Process all events - Scopes are marked dirty, etc
             // Sometimes when wakers fire we get a slew of updates at once, so its important that we drain this completely
@@ -469,17 +458,22 @@ impl VirtualDom {
             let _runtime = RuntimeGuard::new(self.runtime.clone());
 
             // There isn't any more work we can do synchronously. Wait for any new work to be ready
-            match self.rx.next().await.expect("channel should never close") {
-                SchedulerMsg::Immediate(id) => self.mark_dirty(id),
-                SchedulerMsg::TaskNotified(id) => {
-                    // Instead of running the task immediately, we insert it into the runtime's task queue.
-                    // The task may be marked dirty at the same time as the scope that owns the task is dropped.
-                    self.mark_task_dirty(id);
-                }
-            };
+            self.wait_for_event().await;
         }
     }
 
+    /// Wait for the next event to trigger and add it to the queue
+    async fn wait_for_event(&mut self) {
+        match self.rx.next().await.expect("channel should never close") {
+            SchedulerMsg::Immediate(id) => self.mark_dirty(id),
+            SchedulerMsg::TaskNotified(id) => {
+                // Instead of running the task immediately, we insert it into the runtime's task queue.
+                // The task may be marked dirty at the same time as the scope that owns the task is dropped.
+                self.mark_task_dirty(id);
+            }
+        };
+    }
+
     /// Queue any pending events
     fn queue_events(&mut self) {
         // Prevent a task from deadlocking the runtime by repeatedly queueing itself
@@ -494,7 +488,6 @@ impl VirtualDom {
     /// Process all events in the queue until there are no more left
     #[instrument(skip(self), level = "trace", name = "VirtualDom::process_events")]
     pub fn process_events(&mut self) {
-        let _runtime = RuntimeGuard::new(self.runtime.clone());
         self.queue_events();
 
         // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
@@ -502,6 +495,14 @@ impl VirtualDom {
             return;
         }
 
+        self.poll_tasks()
+    }
+
+    /// Poll any queued tasks
+    #[instrument(skip(self), level = "trace", name = "VirtualDom::poll_tasks")]
+    fn poll_tasks(&mut self) {
+        // Make sure we set the runtime since we're running user code
+        let _runtime = RuntimeGuard::new(self.runtime.clone());
         // Next, run any queued tasks
         // We choose not to poll the deadline since we complete pretty quickly anyways
         while let Some(task) = self.dirty_scopes.pop_task() {
@@ -617,7 +618,6 @@ impl VirtualDom {
             {
                 let _runtime = RuntimeGuard::new(self.runtime.clone());
                 // Then, poll any tasks that might be pending in the scope
-                // This will run effects, so this **must** be done after the scope is diffed
                 for task in work.tasks {
                     let _ = self.runtime.handle_task_wakeup(task);
                 }
@@ -649,15 +649,75 @@ impl VirtualDom {
     #[instrument(skip(self), level = "trace", name = "VirtualDom::wait_for_suspense")]
     pub async fn wait_for_suspense(&mut self) {
         loop {
-            if self.suspended_scopes.is_empty() {
+            if self.runtime.suspended_tasks.borrow().is_empty() {
                 break;
             }
 
             // Wait for a work to be ready (IE new suspense leaves to pop up)
-            self.poll_tasks().await;
+            'wait_for_work: loop {
+                // Process all events - Scopes are marked dirty, etc
+                // Sometimes when wakers fire we get a slew of updates at once, so its important that we drain this completely
+                self.queue_events();
+
+                // Now that we have collected all queued work, we should check if we have any dirty scopes. If there are not, then we can poll any queued futures
+                if self.dirty_scopes.has_dirty_scopes() {
+                    println!("dirty scopes");
+                    break;
+                }
+
+                {
+                    // Make sure we set the runtime since we're running user code
+                    let _runtime = RuntimeGuard::new(self.runtime.clone());
+                    // Next, run any queued tasks
+                    // We choose not to poll the deadline since we complete pretty quickly anyways
+                    while let Some(task) = self.dirty_scopes.pop_task() {
+                        // If the scope doesn't exist for whatever reason, then we should skip it
+                        if !self.scopes.contains(task.order.id.0) {
+                            continue;
+                        }
+
+                        // Then poll any tasks that might be pending
+                        let tasks = task.tasks_queued.into_inner();
+                        for task in tasks {
+                            if self.runtime.suspended_tasks.borrow().contains(&task) {
+                                let _ = self.runtime.handle_task_wakeup(task);
+                                // Running that task, may mark a scope higher up as dirty. If it does, return from the function early
+                                self.queue_events();
+                                if self.dirty_scopes.has_dirty_scopes() {
+                                    break 'wait_for_work;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                self.wait_for_event().await;
+            }
 
             // Render whatever work needs to be rendered, unlocking new futures and suspense leaves
-            self.render_immediate(&mut NoOpMutations);
+            while let Some(work) = self.dirty_scopes.pop_work() {
+                // If the scope doesn't exist for whatever reason, then we should skip it
+                if !self.scopes.contains(work.scope.id.0) {
+                    continue;
+                }
+
+                {
+                    let _runtime = RuntimeGuard::new(self.runtime.clone());
+                    // Then, poll any tasks that might be pending in the scope
+                    for task in work.tasks {
+                        // During suspense, we only want to run tasks that are suspended
+                        if self.runtime.suspended_tasks.borrow().contains(&task) {
+                            let _ = self.runtime.handle_task_wakeup(task);
+                        }
+                    }
+                    // If the scope is dirty, run the scope and get the mutations
+                    if work.rerun_scope {
+                        let new_nodes = self.run_scope(work.scope.id);
+
+                        self.diff_scope(&mut NoOpMutations, work.scope.id, new_nodes);
+                    }
+                }
+            }
         }
     }
 

+ 28 - 3
packages/core/tests/suspense.rs

@@ -1,9 +1,10 @@
 use dioxus::prelude::*;
+use std::future::poll_fn;
+use std::task::Poll;
 
 #[test]
 fn suspense_resolves() {
     // wait just a moment, not enough time for the boundary to resolve
-
     tokio::runtime::Builder::new_current_thread()
         .build()
         .unwrap()
@@ -31,11 +32,35 @@ fn app() -> Element {
 fn suspended_child() -> Element {
     let mut val = use_signal(|| 0);
 
+    // Tasks that are not suspended should never be polled
+    spawn(async move {
+        panic!("Non-suspended task was polled");
+    });
+
+    // // Memos should still work like normal
+    // let memo = use_memo(move || val * 2);
+    // assert_eq!(memo, val * 2);
+
     if val() < 3 {
-        spawn(async move {
+        let task = spawn(async move {
+            // Poll each task 3 times
+            let mut count = 0;
+            poll_fn(|cx| {
+                println!("polling... {}", count);
+                if count < 3 {
+                    count += 1;
+                    cx.waker().wake_by_ref();
+                    Poll::Pending
+                } else {
+                    Poll::Ready(())
+                }
+            })
+            .await;
+
+            println!("waiting... {}", val);
             val += 1;
         });
-        suspend()?;
+        suspend(task)?;
     }
 
     rsx!("child")