浏览代码

fix deadlock and add more tests

Evan Almloff 2 年之前
父节点
当前提交
7d3ac26fce
共有 1 个文件被更改,包括 250 次插入43 次删除
  1. 250 43
      packages/native-core/src/passes.rs

+ 250 - 43
packages/native-core/src/passes.rs

@@ -1,15 +1,14 @@
 use crossbeam_deque::{Injector, Stealer, Worker};
 use crossbeam_deque::{Injector, Stealer, Worker};
-use parking_lot::RwLock;
+use parking_lot::{Condvar, Mutex, RwLock};
 use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
 use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
 use std::hash::BuildHasherDefault;
 use std::hash::BuildHasherDefault;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use std::sync::Arc;
 use std::thread;
 use std::thread;
-use std::time::Duration;
 
 
 use crate::tree::{NodeId, SharedView, TreeView};
 use crate::tree::{NodeId, SharedView, TreeView};
 
 
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
 pub struct PassId(u64);
 pub struct PassId(u64);
 
 
 pub trait UpwardPass<T> {
 pub trait UpwardPass<T> {
@@ -124,7 +123,7 @@ fn get_pass<T, Tr: TreeView<T>>(
     dirty_nodes: &DirtyNodeStates,
     dirty_nodes: &DirtyNodeStates,
     shared_view: &mut SharedView<T, Tr>,
     shared_view: &mut SharedView<T, Tr>,
     global: &Injector<NodeId>,
     global: &Injector<NodeId>,
-    current_pass: &RwLock<Option<AnyPass<T>>>,
+    current_pass: &mut Option<AnyPass<T>>,
 ) {
 ) {
     for i in 0..passes.len() {
     for i in 0..passes.len() {
         if passes[i]
         if passes[i]
@@ -134,6 +133,7 @@ fn get_pass<T, Tr: TreeView<T>>(
         {
         {
             let pass = passes.remove(i);
             let pass = passes.remove(i);
             let pass_id = pass.pass_id();
             let pass_id = pass.pass_id();
+            println!("Running pass {:?}", pass_id);
             resolved_passes.insert(pass_id);
             resolved_passes.insert(pass_id);
             match pass {
             match pass {
                 AnyPass::Upward(pass) => {
                 AnyPass::Upward(pass) => {
@@ -164,13 +164,21 @@ fn get_pass<T, Tr: TreeView<T>>(
                     for node in dirty_nodes.all_dirty(pass_id) {
                     for node in dirty_nodes.all_dirty(pass_id) {
                         global.push(node);
                         global.push(node);
                     }
                     }
-                    current_pass.write().replace(pass);
+                    println!(
+                        "Task: {:?} {:?}",
+                        pass_id,
+                        dirty_nodes.all_dirty(pass_id).collect::<Vec<_>>()
+                    );
+                    current_pass.replace(pass);
                 }
                 }
             }
             }
-
-            break;
+            return;
         }
         }
     }
     }
+    panic!(
+        "No pass found with all dependancies resolved in {:?}",
+        passes.iter().map(|p| p.pass_id()).collect::<Vec<_>>()
+    );
 }
 }
 
 
 pub fn resolve_passes<T>(
 pub fn resolve_passes<T>(
@@ -178,6 +186,7 @@ pub fn resolve_passes<T>(
     starting_nodes: FxHashMap<NodeId, FxHashSet<PassId>>,
     starting_nodes: FxHashMap<NodeId, FxHashSet<PassId>>,
     mut passes: Vec<AnyPass<T>>,
     mut passes: Vec<AnyPass<T>>,
 ) {
 ) {
+    assert!(!passes.is_empty());
     let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
     let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
     let global = Injector::default();
     let global = Injector::default();
 
 
@@ -191,69 +200,105 @@ pub fn resolve_passes<T>(
     let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
     let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
 
 
     thread::scope(|s| {
     thread::scope(|s| {
-        get_pass(
-            &mut passes,
-            &mut resolved_passes,
-            &dirty_nodes,
-            &mut shared_view,
-            &global,
-            &current_pass,
-        );
+        {
+            let mut write = current_pass.write();
+            get_pass(
+                &mut passes,
+                &mut resolved_passes,
+                &dirty_nodes,
+                &mut shared_view,
+                &global,
+                &mut *write,
+            );
+        }
         let global = &global;
         let global = &global;
         let stealers = &stealers;
         let stealers = &stealers;
+        let mut thread_handles = Vec::new();
+        let threads_finished = Arc::new((Mutex::new(0), Condvar::new()));
         for (_, w) in (0..core_count).zip(workers.into_iter()) {
         for (_, w) in (0..core_count).zip(workers.into_iter()) {
             let mut shared_view = shared_view.clone();
             let mut shared_view = shared_view.clone();
             let current_pass = current_pass.clone();
             let current_pass = current_pass.clone();
             let dirty_nodes = dirty_nodes.clone();
             let dirty_nodes = dirty_nodes.clone();
-            s.spawn(move || {
-                while let Some(current_pass) = &*current_pass.read() {
-                    match current_pass {
-                        AnyPass::Upward(_) => {
-                            todo!("Upward passes are single threaded")
-                        }
-                        AnyPass::Node(pass) => {
-                            // Node passes are the easiest to parallelize. We just run the pass on each node.
-                            while let Some(id) = find_task(&w, global, stealers) {
-                                let node = shared_view.get_mut(id).unwrap();
-                                if pass.pass(node) {
-                                    for dependant in pass.dependants() {
-                                        dirty_nodes.insert(*dependant, id);
+            let threads_finished = threads_finished.clone();
+            thread_handles.push(s.spawn(move || {
+                loop {
+                    let read = current_pass.read();
+                    if let Some(current_pass) = &*read {
+                        match current_pass {
+                            AnyPass::Upward(_) => {
+                                todo!("Upward passes are single threaded")
+                            }
+                            AnyPass::Node(pass) => {
+                                // Node passes are the easiest to parallelize. We just run the pass on each node.
+                                while let Some(id) = find_task(&w, global, stealers) {
+                                    let node = shared_view.get_mut(id).unwrap();
+                                    if pass.pass(node) {
+                                        for dependant in pass.dependants() {
+                                            dirty_nodes.insert(*dependant, id);
+                                        }
                                     }
                                     }
                                 }
                                 }
                             }
                             }
-                        }
-                        AnyPass::Downward(pass) => {
-                            // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load.
-                            while let Some(id) = find_task(&w, global, stealers) {
-                                let (node, parent) = shared_view.node_parent_mut(id).unwrap();
-                                if pass.pass(node, parent) {
-                                    for id in shared_view.children_ids(id).unwrap() {
-                                        for dependant in pass.dependants() {
-                                            dirty_nodes.insert(*dependant, *id);
+                            AnyPass::Downward(pass) => {
+                                // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load.
+                                while let Some(id) = find_task(&w, global, stealers) {
+                                    let (node, parent) = shared_view.node_parent_mut(id).unwrap();
+                                    if pass.pass(node, parent) {
+                                        for id in shared_view.children_ids(id).unwrap() {
+                                            for dependant in pass.dependants() {
+                                                dirty_nodes.insert(*dependant, *id);
+                                            }
+                                            w.push(*id);
                                         }
                                         }
-                                        w.push(*id);
                                     }
                                     }
                                 }
                                 }
                             }
                             }
                         }
                         }
+                        // unblock the rwlock
+                        drop(read);
+                        {
+                            let (count, cvar) = &*threads_finished;
+                            let mut lock = count.lock();
+                            *lock += 1;
+                            if *lock == core_count {
+                                cvar.notify_all();
+                            }
+                        }
+                        // wait for the main thread to pick the next pass
+                        thread::park();
+                    } else {
+                        break;
                     }
                     }
                 }
                 }
-            });
+            }));
         }
         }
+        let wait_for_thread_to_finish = move || {
+            // this will block until all threads are done with the current pass
+            let (count, cvar) = &*threads_finished;
+            let mut count = count.lock();
+            cvar.wait(&mut count);
+        };
         while !passes.is_empty() {
         while !passes.is_empty() {
-            while !stealers.iter().all(|s| s.is_empty()) {
-                std::thread::sleep(Duration::from_millis(50));
-            }
+            wait_for_thread_to_finish();
+            let mut write = current_pass.write();
+            println!("Threads finished");
             get_pass(
             get_pass(
                 &mut passes,
                 &mut passes,
                 &mut resolved_passes,
                 &mut resolved_passes,
                 &dirty_nodes,
                 &dirty_nodes,
                 &mut shared_view,
                 &mut shared_view,
                 global,
                 global,
-                &current_pass,
+                &mut *write,
             );
             );
+            // notify all threads to start the next pass
+            for thread in &thread_handles {
+                thread.thread().unpark();
+            }
         }
         }
         *current_pass.write() = None;
         *current_pass.write() = None;
+        for thread in &thread_handles {
+            thread.thread().unpark();
+        }
     });
     });
 }
 }
 
 
@@ -318,6 +363,75 @@ fn node_pass() {
     assert_eq!(tree.get(tree.root()).unwrap(), &1);
     assert_eq!(tree.get(tree.root()).unwrap(), &1);
 }
 }
 
 
+#[test]
+fn dependant_node_pass() {
+    use crate::tree::{Tree, TreeLike};
+    let mut tree = Tree::new(0);
+    let parent = tree.root();
+    let child1 = tree.create_node(1);
+    tree.add_child(parent, child1);
+    let grandchild1 = tree.create_node(3);
+    tree.add_child(child1, grandchild1);
+    let child2 = tree.create_node(2);
+    tree.add_child(parent, child2);
+    let grandchild2 = tree.create_node(4);
+    tree.add_child(child2, grandchild2);
+
+    struct AddPass;
+
+    impl NodePass<i32> for AddPass {
+        fn pass_id(&self) -> PassId {
+            PassId(0)
+        }
+
+        fn dependancies(&self) -> &'static [PassId] {
+            &[PassId(1)]
+        }
+
+        fn dependants(&self) -> &'static [PassId] {
+            &[]
+        }
+
+        fn pass(&self, node: &mut i32) -> bool {
+            println!("AddPass: {}", node);
+            *node += 1;
+            true
+        }
+    }
+
+    struct SubtractPass;
+
+    impl NodePass<i32> for SubtractPass {
+        fn pass_id(&self) -> PassId {
+            PassId(1)
+        }
+
+        fn dependancies(&self) -> &'static [PassId] {
+            &[]
+        }
+
+        fn dependants(&self) -> &'static [PassId] {
+            &[PassId(0)]
+        }
+
+        fn pass(&self, node: &mut i32) -> bool {
+            println!("SubtractPass: {}", node);
+            *node -= 1;
+            true
+        }
+    }
+
+    let passes = vec![
+        AnyPass::Node(Box::new(AddPass)),
+        AnyPass::Node(Box::new(SubtractPass)),
+    ];
+    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
+    dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
+    resolve_passes(&mut tree, dirty_nodes, passes);
+
+    assert_eq!(*tree.get(tree.root()).unwrap(), 0);
+}
+
 #[test]
 #[test]
 fn down_pass() {
 fn down_pass() {
     use crate::tree::{Tree, TreeLike};
     use crate::tree::{Tree, TreeLike};
@@ -367,6 +481,99 @@ fn down_pass() {
     assert_eq!(tree.get(grandchild2).unwrap(), &3);
     assert_eq!(tree.get(grandchild2).unwrap(), &3);
 }
 }
 
 
+#[test]
+fn dependant_down_pass() {
+    use crate::tree::{Tree, TreeLike};
+    let mut tree = Tree::new(1);
+    let parent = tree.root();
+    let child1 = tree.create_node(1);
+    tree.add_child(parent, child1);
+    let grandchild1 = tree.create_node(1);
+    tree.add_child(child1, grandchild1);
+    let child2 = tree.create_node(1);
+    tree.add_child(parent, child2);
+    let grandchild2 = tree.create_node(1);
+    tree.add_child(child2, grandchild2);
+
+    struct AddPass;
+
+    impl DownwardPass<i32> for AddPass {
+        fn pass_id(&self) -> PassId {
+            PassId(0)
+        }
+
+        fn dependancies(&self) -> &'static [PassId] {
+            &[PassId(1)]
+        }
+
+        fn dependants(&self) -> &'static [PassId] {
+            &[]
+        }
+
+        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
+            if let Some(parent) = parent {
+                *node += *parent;
+            }
+            true
+        }
+    }
+
+    struct SubtractPass;
+
+    impl DownwardPass<i32> for SubtractPass {
+        fn pass_id(&self) -> PassId {
+            PassId(1)
+        }
+
+        fn dependancies(&self) -> &'static [PassId] {
+            &[]
+        }
+
+        fn dependants(&self) -> &'static [PassId] {
+            &[PassId(0)]
+        }
+
+        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
+            if let Some(parent) = parent {
+                *node -= *parent;
+            }
+            true
+        }
+    }
+
+    let passes = vec![
+        AnyPass::Downward(Box::new(AddPass)),
+        AnyPass::Downward(Box::new(SubtractPass)),
+    ];
+    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
+    dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
+    resolve_passes(&mut tree, dirty_nodes, passes);
+
+    // Tree before:
+    // 1=\
+    //   1=\
+    //     1
+    //   1=\
+    //     1
+    // Tree after subtract:
+    // 1=\
+    //   0=\
+    //     1
+    //   0=\
+    //     1
+    // Tree after add:
+    // 1=\
+    //   1=\
+    //     2
+    //   1=\
+    //     2
+    assert_eq!(tree.get(tree.root()).unwrap(), &1);
+    assert_eq!(tree.get(child1).unwrap(), &1);
+    assert_eq!(tree.get(grandchild1).unwrap(), &2);
+    assert_eq!(tree.get(child2).unwrap(), &1);
+    assert_eq!(tree.get(grandchild2).unwrap(), &2);
+}
+
 #[test]
 #[test]
 fn up_pass() {
 fn up_pass() {
     use crate::tree::{Tree, TreeLike};
     use crate::tree::{Tree, TreeLike};