Explorar o código

fix memos during suspense

Evan Almloff hai 1 ano
pai
achega
3d7f419636

+ 3 - 3
packages/core/tests/suspense.rs

@@ -37,9 +37,9 @@ fn suspended_child() -> Element {
         panic!("Non-suspended task was polled");
     });
 
-    // // Memos should still work like normal
-    // let memo = use_memo(move || val * 2);
-    // assert_eq!(memo, val * 2);
+    // Memos should still work like normal
+    let memo = use_memo(move || val * 2);
+    assert_eq!(memo, val * 2);
 
     if val() < 3 {
         let task = spawn(async move {

+ 2 - 2
packages/core/tests/task.rs

@@ -53,7 +53,7 @@ async fn running_async() {
 #[tokio::test]
 async fn yield_now_works() {
     thread_local! {
-        static SEQUENCE: std::cell::RefCell<Vec<usize>> = std::cell::RefCell::new(Vec::new());
+        static SEQUENCE: std::cell::RefCell<Vec<usize>> = const { std::cell::RefCell::new(Vec::new()) };
     }
 
     fn app() -> Element {
@@ -88,7 +88,7 @@ async fn yield_now_works() {
 #[tokio::test]
 async fn flushing() {
     thread_local! {
-        static SEQUENCE: std::cell::RefCell<Vec<usize>> = std::cell::RefCell::new(Vec::new());
+        static SEQUENCE: std::cell::RefCell<Vec<usize>> = const { std::cell::RefCell::new(Vec::new()) };
         static BROADCAST: (tokio::sync::broadcast::Sender<()>, tokio::sync::broadcast::Receiver<()>) = tokio::sync::broadcast::channel(1);
     }
 

+ 3 - 0
packages/signals/src/copy_value.rs

@@ -237,14 +237,17 @@ impl<T: 'static, S: Storage<T>> Writable for CopyValue<T, S> {
         S::try_map_mut(mut_, f)
     }
 
+    #[track_caller]
     fn try_write(&mut self) -> Result<Self::Mut<T>, generational_box::BorrowMutError> {
         self.value.try_write()
     }
 
+    #[track_caller]
     fn write(&mut self) -> Self::Mut<T> {
         self.value.write()
     }
 
+    #[track_caller]
     fn set(&mut self, value: T) {
         self.value.set(value);
     }

+ 54 - 7
packages/signals/src/memo.rs

@@ -1,6 +1,7 @@
 use crate::write::Writable;
 use crate::{read::Readable, ReactiveContext, ReadableRef, Signal};
 use crate::{CopyValue, ReadOnlySignal};
+use std::rc::Rc;
 use std::{
     cell::RefCell,
     ops::Deref,
@@ -11,6 +12,33 @@ use std::{
 use dioxus_core::prelude::*;
 use futures_util::StreamExt;
 use generational_box::UnsyncStorage;
+use once_cell::sync::OnceCell;
+
+/// A thread local that can only be read from the thread it was created on.
+pub struct ThreadLocal<T> {
+    value: T,
+    owner: std::thread::ThreadId,
+}
+
+impl<T> ThreadLocal<T> {
+    /// Create a new thread local.
+    pub fn new(value: T) -> Self {
+        ThreadLocal {
+            value,
+            owner: std::thread::current().id(),
+        }
+    }
+
+    /// Get the value of the thread local.
+    pub fn get(&self) -> Option<&T> {
+        (self.owner == std::thread::current().id()).then_some(&self.value)
+    }
+}
+
+// SAFETY: This is safe because the thread local can only be read from the thread it was created on.
+unsafe impl<T> Send for ThreadLocal<T> {}
+unsafe impl<T> Sync for ThreadLocal<T> {}
+
 struct UpdateInformation<T> {
     dirty: Arc<AtomicBool>,
     callback: RefCell<Box<dyn FnMut() -> T>>,
@@ -41,11 +69,26 @@ impl<T: 'static> Memo<T> {
         let dirty = Arc::new(AtomicBool::new(true));
         let (tx, mut rx) = futures_channel::mpsc::unbounded();
 
+        let myself: Rc<OnceCell<Memo<T>>> = Rc::new(OnceCell::new());
+        let thread_local = ThreadLocal::new(myself.clone());
+
         let callback = {
             let dirty = dirty.clone();
-            move || {
-                dirty.store(true, std::sync::atomic::Ordering::Relaxed);
-                tx.unbounded_send(()).unwrap();
+            move || match thread_local.get() {
+                Some(memo) => match memo.get() {
+                    Some(memo) => {
+                        memo.recompute();
+                    }
+                    None => {
+                        tracing::error!("Memo was not initialized in the same thread it was created in. This is likely a bug in dioxus");
+                        dirty.store(true, std::sync::atomic::Ordering::Relaxed);
+                        let _ = tx.unbounded_send(());
+                    }
+                },
+                None => {
+                    dirty.store(true, std::sync::atomic::Ordering::Relaxed);
+                    let _ = tx.unbounded_send(());
+                }
             }
         };
         let rc = ReactiveContext::new_with_callback(
@@ -55,8 +98,9 @@ impl<T: 'static> Memo<T> {
         );
 
         // Create a new signal in that context, wiring up its dependencies and subscribers
-        let value = rc.run_in(&mut f);
-        let recompute = RefCell::new(Box::new(f) as Box<dyn FnMut() -> T>);
+        let mut recompute = move || rc.run_in(&mut f);
+        let value = recompute();
+        let recompute = RefCell::new(Box::new(recompute) as Box<dyn FnMut() -> T>);
         let update = CopyValue::new(UpdateInformation {
             dirty,
             callback: recompute,
@@ -67,9 +111,12 @@ impl<T: 'static> Memo<T> {
             inner: state,
             update,
         };
+        let _ = myself.set(memo);
 
         spawn(async move {
             while rx.next().await.is_some() {
+                // Remove any pending updates
+                while rx.try_next().is_ok() {}
                 memo.recompute();
             }
         });
@@ -78,6 +125,7 @@ impl<T: 'static> Memo<T> {
     }
 
     /// Rerun the computation and update the value of the memo if the result has changed.
+    #[tracing::instrument(skip(self))]
     fn recompute(&self)
     where
         T: PartialEq,
@@ -89,8 +137,7 @@ impl<T: 'static> Memo<T> {
         if new_value != *peak {
             drop(peak);
             let mut copy = self.inner;
-            let mut write = copy.write();
-            *write = new_value;
+            copy.set(new_value);
             update_write
                 .dirty
                 .store(false, std::sync::atomic::Ordering::Relaxed);

+ 5 - 3
packages/signals/src/reactive_context.rs

@@ -24,10 +24,12 @@ thread_local! {
 
 impl std::fmt::Display for ReactiveContext {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let read = self.inner.read();
         #[cfg(debug_assertions)]
-        return write!(f, "ReactiveContext created at {}", read.origin);
-        #[cfg(not(debug_assertions))]
+        {
+            if let Ok(read) = self.inner.try_read() {
+                return write!(f, "ReactiveContext created at {}", read.origin);
+            }
+        }
         write!(f, "ReactiveContext")
     }
 }

+ 4 - 2
packages/signals/src/signal.rs

@@ -150,8 +150,10 @@ impl<T: 'static, S: Storage<SignalData<T>>> Signal<T, S> {
         {
             let inner = self.inner.read();
 
-            let mut subscribers = inner.subscribers.lock().unwrap();
-            subscribers.retain(|reactive_context| reactive_context.mark_dirty())
+            // We cannot hold the subscribers lock while calling mark_dirty, because mark_dirty can run user code which may cause a new subscriber to be added. If we hold the lock, we will deadlock.
+            let mut subscribers = std::mem::take(&mut *inner.subscribers.lock().unwrap());
+            subscribers.retain(|reactive_context| reactive_context.mark_dirty());
+            *inner.subscribers.lock().unwrap() = subscribers;
         }
     }