Przeglądaj źródła

make passes exicute in parallel instead of executing invidual passes in parellel

Evan Almloff 2 lat temu
rodzic
commit
51f643c5dc
1 zmienionych plików z 329 dodań i 282 usunięć
  1. 329 282
      packages/native-core/src/passes.rs

+ 329 - 282
packages/native-core/src/passes.rs

@@ -1,61 +1,63 @@
 use crossbeam_deque::{Injector, Stealer, Worker};
-use parking_lot::{Condvar, Mutex, RwLock};
 use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
+use std::collections::{BTreeMap, HashMap};
 use std::hash::BuildHasherDefault;
+use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign};
+use std::os::raw;
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
-use std::thread;
 
-use crate::tree::{NodeId, SharedView, TreeView};
+use crate::tree::{NodeId, TreeView};
 
-#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
-pub struct PassId(u64);
-
-pub trait UpwardPass<T> {
-    fn pass_id(&self) -> PassId;
-    fn dependancies(&self) -> &'static [PassId];
-    fn dependants(&self) -> &'static [PassId];
-    fn pass<'a>(&self, node: &mut T, children: &mut dyn Iterator<Item = &'a mut T>) -> bool;
-}
-
-pub trait DownwardPass<T> {
-    fn pass_id(&self) -> PassId;
-    fn dependancies(&self) -> &'static [PassId];
-    fn dependants(&self) -> &'static [PassId];
-    fn pass(&self, node: &mut T, parent: Option<&mut T>) -> bool;
-}
-
-pub trait NodePass<T> {
-    fn pass_id(&self) -> PassId;
-    fn dependancies(&self) -> &'static [PassId];
-    fn dependants(&self) -> &'static [PassId];
-    fn pass(&self, node: &mut T) -> bool;
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub struct DirtyNodes {
+    map: BTreeMap<u16, FxHashSet<NodeId>>,
 }
 
-pub enum AnyPass<T> {
-    Upward(Box<dyn UpwardPass<T> + Send + Sync>),
-    Downward(Box<dyn DownwardPass<T> + Send + Sync>),
-    Node(Box<dyn NodePass<T> + Send + Sync>),
-}
+impl DirtyNodes {
+    pub fn insert(&mut self, depth: u16, node_id: NodeId) {
+        self.map
+            .entry(depth)
+            .or_insert_with(FxHashSet::default)
+            .insert(node_id);
+    }
 
-impl<T> AnyPass<T> {
-    fn pass_id(&self) -> PassId {
-        match self {
-            Self::Upward(pass) => pass.pass_id(),
-            Self::Downward(pass) => pass.pass_id(),
-            Self::Node(pass) => pass.pass_id(),
+    fn pop_front(&mut self) -> Option<NodeId> {
+        let (&depth, values) = self.map.iter_mut().next()?;
+        let key = *values.iter().next()?;
+        let node_id = values.take(&key)?;
+        if values.is_empty() {
+            self.map.remove(&depth);
         }
+        Some(node_id)
     }
 
-    fn dependancies(&self) -> &'static [PassId] {
-        match self {
-            Self::Upward(pass) => pass.dependancies(),
-            Self::Downward(pass) => pass.dependancies(),
-            Self::Node(pass) => pass.dependancies(),
+    fn pop_back(&mut self) -> Option<NodeId> {
+        let (&depth, values) = self.map.iter_mut().rev().next()?;
+        let key = *values.iter().next()?;
+        let node_id = values.take(&key)?;
+        if values.is_empty() {
+            self.map.remove(&depth);
         }
+        Some(node_id)
     }
 }
 
+#[test]
+fn dirty_nodes() {
+    let mut dirty_nodes = DirtyNodes::default();
+
+    dirty_nodes.insert(1, NodeId(1));
+    dirty_nodes.insert(0, NodeId(0));
+    dirty_nodes.insert(2, NodeId(3));
+    dirty_nodes.insert(1, NodeId(2));
+
+    assert_eq!(dirty_nodes.pop_front(), Some(NodeId(0)));
+    assert!(matches!(dirty_nodes.pop_front(), Some(NodeId(1 | 2))));
+    assert!(matches!(dirty_nodes.pop_front(), Some(NodeId(1 | 2))));
+    assert_eq!(dirty_nodes.pop_front(), Some(NodeId(3)));
+}
+
 type FxDashMap<K, V> = dashmap::DashMap<K, V, BuildHasherDefault<FxHasher>>;
 
 #[derive(Default)]
@@ -112,247 +114,248 @@ impl DirtyNodeStates {
         }
     }
 
-    fn all_dirty(&self, pass_id: PassId) -> impl Iterator<Item = NodeId> + '_ {
+    fn all_dirty<T>(&self, pass_id: PassId, dirty_nodes: &mut DirtyNodes, tree: &impl TreeView<T>) {
         let pass_id = pass_id.0;
         let index = pass_id / 64;
         let bit = pass_id % 64;
         let encoded = 1 << bit;
-        self.dirty.iter().filter_map(move |entry| {
+        for entry in self.dirty.iter() {
             let node_id = entry.key();
             let dirty = entry.value();
             if let Some(atomic) = dirty.get(index as usize) {
                 if atomic.load(Ordering::Relaxed) & encoded != 0 {
-                    Some(*node_id)
-                } else {
-                    None
+                    dirty_nodes.insert(tree.height(*node_id).unwrap(), *node_id);
                 }
-            } else {
-                None
             }
-        })
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
+pub struct PassId(u64);
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Default)]
+pub struct MemberMask(u64);
+
+impl MemberMask {
+    pub fn overlaps(&self, other: Self) -> bool {
+        (*self & other).0 != 0
+    }
+}
+
+impl BitAndAssign for MemberMask {
+    fn bitand_assign(&mut self, rhs: Self) {
+        self.0 &= rhs.0;
+    }
+}
+
+impl BitAnd for MemberMask {
+    type Output = Self;
+
+    fn bitand(self, rhs: Self) -> Self::Output {
+        MemberMask(self.0 & rhs.0)
+    }
+}
+
+impl BitOrAssign for MemberMask {
+    fn bitor_assign(&mut self, rhs: Self) {
+        self.0 |= rhs.0;
+    }
+}
+
+impl BitOr for MemberMask {
+    type Output = Self;
+
+    fn bitor(self, rhs: Self) -> Self::Output {
+        Self(self.0 | rhs.0)
     }
 }
 
-fn get_pass<T, Tr: TreeView<T>>(
-    passes: &mut Vec<AnyPass<T>>,
-    resolved_passes: &mut FxHashSet<PassId>,
-    dirty_nodes: &DirtyNodeStates,
-    shared_view: &mut SharedView<T, Tr>,
-    global: &Injector<NodeId>,
-    current_pass: &mut Option<AnyPass<T>>,
+pub struct PassReturn {
+    progress: bool,
+    mark_dirty: bool,
+}
+
+pub trait Pass {
+    fn pass_id(&self) -> PassId;
+    fn dependancies(&self) -> &'static [PassId];
+    fn dependants(&self) -> &'static [PassId];
+    fn mask(&self) -> MemberMask;
+}
+
+pub trait UpwardPass<T>: Pass {
+    fn pass<'a>(&self, node: &mut T, children: &mut dyn Iterator<Item = &'a mut T>) -> PassReturn;
+}
+
+fn resolve_upward_pass<T, P: UpwardPass<T> + ?Sized>(
+    tree: &mut impl TreeView<T>,
+    pass: &P,
+    mut dirty: DirtyNodes,
+    dirty_states: &DirtyNodeStates,
 ) {
-    for i in 0..passes.len() {
-        if passes[i]
-            .dependancies()
-            .iter()
-            .all(|id| resolved_passes.contains(id))
-        {
-            let pass = passes.remove(i);
-            let pass_id = pass.pass_id();
-            println!("Running pass {:?}", pass_id);
-            resolved_passes.insert(pass_id);
-            match pass {
-                AnyPass::Upward(pass) => {
-                    // Upward passes are more difficult. Right now we limit them to only one thread.
-                    let worker = Worker::new_fifo();
-                    let mut queued_nodes = FxHashSet::default();
-                    for node in dirty_nodes.all_dirty(pass_id) {
-                        queued_nodes.insert(node);
-                        worker.push(node);
-                    }
-                    while let Some(id) = worker.pop() {
-                        let (node, mut children) = shared_view.parent_child_mut(id).unwrap();
-                        if pass.pass(node, &mut children) {
-                            drop(children);
-                            if let Some(id) = shared_view.parent_id(id) {
-                                for dependant in pass.dependants() {
-                                    dirty_nodes.insert(*dependant, id);
-                                }
-                                if !queued_nodes.contains(&id) {
-                                    queued_nodes.insert(id);
-                                    worker.push(id);
-                                }
-                            }
-                        }
+    while let Some(id) = dirty.pop_back() {
+        let (node, mut children) = tree.parent_child_mut(id).unwrap();
+        let result = pass.pass(node, &mut children);
+        drop(children);
+        if result.progress || result.mark_dirty {
+            if let Some(id) = tree.parent_id(id) {
+                if result.mark_dirty {
+                    for dependant in pass.dependants() {
+                        dirty_states.insert(*dependant, id);
                     }
                 }
-                AnyPass::Downward(pass) => {
-                    let mut sorted: Vec<_> = dirty_nodes.all_dirty(pass_id).collect();
-                    sorted.sort_unstable_by_key(|id| shared_view.height(*id));
-                    println!(
-                        "Task: {:?} {:?}",
-                        pass_id,
-                        sorted
-                            .iter()
-                            .map(|id| (id, shared_view.height(*id)))
-                            .collect::<Vec<_>>()
-                    );
-                    for node in sorted.into_iter() {
-                        global.push(node);
-                    }
-                    current_pass.replace(AnyPass::Downward(pass));
-                }
-                AnyPass::Node(pass) => {
-                    for node in dirty_nodes.all_dirty(pass_id) {
-                        global.push(node);
-                    }
-                    println!(
-                        "Task: {:?} {:?}",
-                        pass_id,
-                        dirty_nodes.all_dirty(pass_id).collect::<Vec<_>>()
-                    );
-                    current_pass.replace(AnyPass::Node(pass));
+                if result.progress {
+                    let height = tree.height(id).unwrap();
+                    dirty.insert(height, id);
                 }
             }
-            return;
         }
     }
-    panic!(
-        "No pass found with all dependancies resolved in {:?}",
-        passes.iter().map(|p| p.pass_id()).collect::<Vec<_>>()
-    );
 }
 
-pub fn resolve_passes<T>(
+pub trait DownwardPass<T>: Pass {
+    fn pass(&self, node: &mut T, parent: Option<&mut T>) -> PassReturn;
+}
+
+fn resolve_downward_pass<T, P: DownwardPass<T> + ?Sized>(
     tree: &mut impl TreeView<T>,
-    starting_nodes: FxHashMap<NodeId, FxHashSet<PassId>>,
-    mut passes: Vec<AnyPass<T>>,
+    pass: &P,
+    mut dirty: DirtyNodes,
+    dirty_states: &DirtyNodeStates,
 ) {
-    assert!(!passes.is_empty());
-    let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
-    let global = Injector::default();
-
-    let core_count = thread::available_parallelism()
-        .map(|c| c.get())
-        .unwrap_or(1);
-    let workers: Vec<Worker<NodeId>> = (0..core_count).map(|_| Worker::new_fifo()).collect();
-    let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
-    let mut shared_view = SharedView::new(tree);
-    let mut resolved_passes: FxHashSet<PassId> = FxHashSet::default();
-    let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
-
-    thread::scope(|s| {
-        {
-            let mut write = current_pass.write();
-            get_pass(
-                &mut passes,
-                &mut resolved_passes,
-                &dirty_nodes,
-                &mut shared_view,
-                &global,
-                &mut *write,
-            );
-        }
-        let global = &global;
-        let stealers = &stealers;
-        let mut thread_handles = Vec::new();
-        let threads_finished = Arc::new((Mutex::new(0), Condvar::new()));
-        for (_, w) in (0..core_count).zip(workers.into_iter()) {
-            let mut shared_view = shared_view.clone();
-            let current_pass = current_pass.clone();
-            let dirty_nodes = dirty_nodes.clone();
-            let threads_finished = threads_finished.clone();
-            thread_handles.push(s.spawn(move || {
-                loop {
-                    let read = current_pass.read();
-                    if let Some(current_pass) = &*read {
-                        let current_pass_id = current_pass.pass_id();
-                        match current_pass {
-                            AnyPass::Upward(_) => {
-                                todo!("Upward passes are single threaded")
-                            }
-                            AnyPass::Node(pass) => {
-                                // Node passes are the easiest to parallelize. We just run the pass on each node.
-                                while let Some(id) = find_task(&w, global, stealers) {
-                                    let node = shared_view.get_mut(id).unwrap();
-                                    if pass.pass(node) {
-                                        for dependant in pass.dependants() {
-                                            dirty_nodes.insert(*dependant, id);
-                                        }
-                                    }
-                                }
-                            }
-                            AnyPass::Downward(pass) => {
-                                // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load.
-                                while let Some(id) = find_task(&w, global, stealers) {
-                                    let (node, parent) = shared_view.node_parent_mut(id).unwrap();
-                                    if pass.pass(node, parent) {
-                                        for id in shared_view.children_ids(id).unwrap() {
-                                            for dependant in pass.dependants() {
-                                                dirty_nodes.insert(*dependant, *id);
-                                            }
-                                            if !dirty_nodes.get(current_pass_id, *id) {
-                                                w.push(*id);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-                        // unblock the rwlock
-                        drop(read);
-                        {
-                            let (count, cvar) = &*threads_finished;
-                            let mut lock = count.lock();
-                            *lock += 1;
-                            if *lock == core_count {
-                                cvar.notify_all();
-                            }
-                        }
-                        // wait for the main thread to pick the next pass
-                        thread::park();
-                    } else {
-                        break;
+    while let Some(id) = dirty.pop_front() {
+        let (node, parent) = tree.node_parent_mut(id).unwrap();
+        let result = pass.pass(node, parent);
+        if result.mark_dirty || result.progress {
+            for id in tree.children_ids(id).unwrap() {
+                if result.mark_dirty {
+                    for dependant in pass.dependants() {
+                        dirty_states.insert(*dependant, *id);
                     }
                 }
-            }));
-        }
-        let wait_for_thread_to_finish = move || {
-            // this will block until all threads are done with the current pass
-            let (count, cvar) = &*threads_finished;
-            let mut count = count.lock();
-            cvar.wait(&mut count);
-        };
-        while !passes.is_empty() {
-            wait_for_thread_to_finish();
-            let mut write = current_pass.write();
-            println!("Threads finished");
-            get_pass(
-                &mut passes,
-                &mut resolved_passes,
-                &dirty_nodes,
-                &mut shared_view,
-                global,
-                &mut *write,
-            );
-            // notify all threads to start the next pass
-            for thread in &thread_handles {
-                thread.thread().unpark();
+                if result.progress {
+                    let height = tree.height(*id).unwrap();
+                    dirty.insert(height, *id);
+                }
             }
         }
-        *current_pass.write() = None;
-        for thread in &thread_handles {
-            thread.thread().unpark();
+    }
+}
+
+pub trait NodePass<T>: Pass {
+    fn pass(&self, node: &mut T) -> bool;
+}
+
+fn resolve_node_pass<T, P: NodePass<T> + ?Sized>(
+    tree: &mut impl TreeView<T>,
+    pass: &P,
+    mut dirty: DirtyNodes,
+    dirty_states: &DirtyNodeStates,
+) {
+    while let Some(id) = dirty.pop_back() {
+        let node = tree.get_mut(id).unwrap();
+        if pass.pass(node) {
+            for dependant in pass.dependants() {
+                dirty_states.insert(*dependant, id);
+            }
         }
-    });
+    }
 }
 
-fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
-    // Pop a task from the local queue, if not empty.
-    local.pop().or_else(|| {
-        // Otherwise, we need to look for a task elsewhere.
-        std::iter::repeat_with(|| {
-            // Try stealing a batch of tasks from the global queue.
-            global
-                .steal_batch_and_pop(local)
-                // Or try stealing a task from one of the other threads.
-                .or_else(|| stealers.iter().map(|s| s.steal()).collect())
-        })
-        // Loop while no task was stolen and any steal operation needs to be retried.
-        .find(|s| !s.is_retry())
-        // Extract the stolen task, if there is one.
-        .and_then(|s| s.success())
-    })
+pub enum AnyPass<T> {
+    Upward(Box<dyn UpwardPass<T> + Send + Sync>),
+    Downward(Box<dyn DownwardPass<T> + Send + Sync>),
+    Node(Box<dyn NodePass<T> + Send + Sync>),
+}
+
+impl<T> AnyPass<T> {
+    fn pass_id(&self) -> PassId {
+        match self {
+            Self::Upward(pass) => pass.pass_id(),
+            Self::Downward(pass) => pass.pass_id(),
+            Self::Node(pass) => pass.pass_id(),
+        }
+    }
+
+    fn dependancies(&self) -> &'static [PassId] {
+        match self {
+            Self::Upward(pass) => pass.dependancies(),
+            Self::Downward(pass) => pass.dependancies(),
+            Self::Node(pass) => pass.dependancies(),
+        }
+    }
+
+    fn mask(&self) -> MemberMask {
+        match self {
+            Self::Upward(pass) => pass.mask(),
+            Self::Downward(pass) => pass.mask(),
+            Self::Node(pass) => pass.mask(),
+        }
+    }
+
+    fn resolve(
+        &self,
+        tree: &mut impl TreeView<T>,
+        dirty: DirtyNodes,
+        dirty_states: &DirtyNodeStates,
+    ) {
+        match self {
+            Self::Downward(pass) => resolve_downward_pass(tree, pass.as_ref(), dirty, dirty_states),
+            Self::Upward(pass) => resolve_upward_pass(tree, pass.as_ref(), dirty, dirty_states),
+            Self::Node(pass) => resolve_node_pass(tree, pass.as_ref(), dirty, dirty_states),
+        }
+    }
+}
+
+struct RawPointer<T>(*mut T);
+unsafe impl<T> Send for RawPointer<T> {}
+unsafe impl<T> Sync for RawPointer<T> {}
+
+fn resolve_passes<T, Tr: TreeView<T>>(
+    tree: &mut Tr,
+    dirty_nodes: DirtyNodeStates,
+    mut passes: Vec<AnyPass<T>>,
+) {
+    let dirty_states = Arc::new(dirty_nodes);
+    let mut resolved_passes: FxHashSet<PassId> = FxHashSet::default();
+    let mut resolving = Vec::new();
+    while !passes.is_empty() {
+        let mut currently_borrowed = MemberMask::default();
+        std::thread::scope(|s| {
+            let mut i = 0;
+            while i < passes.len() {
+                let pass = &passes[i];
+                let pass_id = pass.pass_id();
+                let pass_mask = pass.mask();
+                if pass
+                    .dependancies()
+                    .iter()
+                    .all(|d| resolved_passes.contains(d))
+                    && !pass_mask.overlaps(currently_borrowed)
+                {
+                    let pass = passes.remove(i);
+                    resolving.push(pass_id);
+                    currently_borrowed |= pass_mask;
+                    let tree_mut = tree as *mut _;
+                    let raw_ptr = RawPointer(tree_mut);
+                    let dirty_states = dirty_states.clone();
+                    s.spawn(move || unsafe {
+                        // let tree_mut: &mut Tr = &mut *raw_ptr.0;
+                        let raw = raw_ptr;
+                        let tree_mut: &mut Tr = &mut *raw.0;
+                        let mut dirty = DirtyNodes::default();
+                        dirty_states.all_dirty(pass_id, &mut dirty, tree_mut);
+                        pass.resolve(tree_mut, dirty, &dirty_states);
+                    });
+                } else {
+                    i += 1;
+                }
+            }
+            // all passes are resolved at the end of the scope
+        });
+        resolved_passes.extend(resolving.iter().copied());
+        resolving.clear()
+    }
 }
 
 #[test]
@@ -371,8 +374,7 @@ fn node_pass() {
     println!("{:#?}", tree);
 
     struct AddPass;
-
-    impl NodePass<i32> for AddPass {
+    impl Pass for AddPass {
         fn pass_id(&self) -> PassId {
             PassId(0)
         }
@@ -385,6 +387,12 @@ fn node_pass() {
             &[]
         }
 
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+
+    impl NodePass<i32> for AddPass {
         fn pass(&self, node: &mut i32) -> bool {
             *node += 1;
             true
@@ -392,8 +400,8 @@ fn node_pass() {
     }
 
     let passes = vec![AnyPass::Node(Box::new(AddPass))];
-    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
-    dirty_nodes.insert(tree.root(), [PassId(0)].into_iter().collect());
+    let mut dirty_nodes: DirtyNodeStates = DirtyNodeStates::default();
+    dirty_nodes.insert(PassId(0), tree.root());
     resolve_passes(&mut tree, dirty_nodes, passes);
 
     assert_eq!(tree.get(tree.root()).unwrap(), &1);
@@ -414,8 +422,7 @@ fn dependant_node_pass() {
     tree.add_child(child2, grandchild2);
 
     struct AddPass;
-
-    impl NodePass<i32> for AddPass {
+    impl Pass for AddPass {
         fn pass_id(&self) -> PassId {
             PassId(0)
         }
@@ -428,6 +435,12 @@ fn dependant_node_pass() {
             &[]
         }
 
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+
+    impl NodePass<i32> for AddPass {
         fn pass(&self, node: &mut i32) -> bool {
             println!("AddPass: {}", node);
             *node += 1;
@@ -437,7 +450,7 @@ fn dependant_node_pass() {
 
     struct SubtractPass;
 
-    impl NodePass<i32> for SubtractPass {
+    impl Pass for SubtractPass {
         fn pass_id(&self) -> PassId {
             PassId(1)
         }
@@ -450,6 +463,11 @@ fn dependant_node_pass() {
             &[PassId(0)]
         }
 
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+    impl NodePass<i32> for SubtractPass {
         fn pass(&self, node: &mut i32) -> bool {
             println!("SubtractPass: {}", node);
             *node -= 1;
@@ -461,8 +479,8 @@ fn dependant_node_pass() {
         AnyPass::Node(Box::new(AddPass)),
         AnyPass::Node(Box::new(SubtractPass)),
     ];
-    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
-    dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
+    let mut dirty_nodes: DirtyNodeStates = DirtyNodeStates::default();
+    dirty_nodes.insert(PassId(1), tree.root());
     resolve_passes(&mut tree, dirty_nodes, passes);
 
     assert_eq!(*tree.get(tree.root()).unwrap(), 0);
@@ -484,7 +502,7 @@ fn down_pass() {
 
     struct AddPass;
 
-    impl DownwardPass<i32> for AddPass {
+    impl Pass for AddPass {
         fn pass_id(&self) -> PassId {
             PassId(0)
         }
@@ -497,17 +515,25 @@ fn down_pass() {
             &[]
         }
 
-        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+    impl DownwardPass<i32> for AddPass {
+        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> PassReturn {
             if let Some(parent) = parent {
                 *node += *parent;
             }
-            true
+            PassReturn {
+                progress: true,
+                mark_dirty: true,
+            }
         }
     }
 
     let passes = vec![AnyPass::Downward(Box::new(AddPass))];
-    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
-    dirty_nodes.insert(tree.root(), [PassId(0)].into_iter().collect());
+    let dirty_nodes: DirtyNodeStates = DirtyNodeStates::default();
+    dirty_nodes.insert(PassId(0), tree.root());
     resolve_passes(&mut tree, dirty_nodes, passes);
 
     assert_eq!(tree.get(tree.root()).unwrap(), &1);
@@ -537,8 +563,7 @@ fn dependant_down_pass() {
     tree.add_child(child2, grandchild2);
 
     struct AddPass;
-
-    impl DownwardPass<i32> for AddPass {
+    impl Pass for AddPass {
         fn pass_id(&self) -> PassId {
             PassId(0)
         }
@@ -551,20 +576,27 @@ fn dependant_down_pass() {
             &[]
         }
 
-        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+    impl DownwardPass<i32> for AddPass {
+        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> PassReturn {
             if let Some(parent) = parent {
                 println!("AddPass: {} -> {}", node, *node + *parent);
                 *node += *parent;
             } else {
                 println!("AddPass: {}", node);
             }
-            true
+            PassReturn {
+                progress: true,
+                mark_dirty: true,
+            }
         }
     }
 
     struct SubtractPass;
-
-    impl DownwardPass<i32> for SubtractPass {
+    impl Pass for SubtractPass {
         fn pass_id(&self) -> PassId {
             PassId(1)
         }
@@ -577,14 +609,22 @@ fn dependant_down_pass() {
             &[PassId(0)]
         }
 
-        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+    impl DownwardPass<i32> for SubtractPass {
+        fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> PassReturn {
             if let Some(parent) = parent {
                 println!("SubtractPass: {} -> {}", node, *node - *parent);
                 *node -= *parent;
             } else {
                 println!("SubtractPass: {}", node);
             }
-            true
+            PassReturn {
+                progress: true,
+                mark_dirty: true,
+            }
         }
     }
 
@@ -592,8 +632,8 @@ fn dependant_down_pass() {
         AnyPass::Downward(Box::new(AddPass)),
         AnyPass::Downward(Box::new(SubtractPass)),
     ];
-    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
-    dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
+    let mut dirty_nodes: DirtyNodeStates = DirtyNodeStates::default();
+    dirty_nodes.insert(PassId(1), tree.root());
     resolve_passes(&mut tree, dirty_nodes, passes);
 
     // Tree before:
@@ -648,8 +688,7 @@ fn up_pass() {
     tree.add_child(child2, grandchild2);
 
     struct AddPass;
-
-    impl UpwardPass<i32> for AddPass {
+    impl Pass for AddPass {
         fn pass_id(&self) -> PassId {
             PassId(0)
         }
@@ -662,20 +701,28 @@ fn up_pass() {
             &[]
         }
 
+        fn mask(&self) -> MemberMask {
+            MemberMask(1)
+        }
+    }
+    impl UpwardPass<i32> for AddPass {
         fn pass<'a>(
             &self,
             node: &mut i32,
             children: &mut dyn Iterator<Item = &'a mut i32>,
-        ) -> bool {
+        ) -> PassReturn {
             *node += children.map(|i| *i).sum::<i32>();
-            true
+            PassReturn {
+                progress: true,
+                mark_dirty: true,
+            }
         }
     }
 
     let passes = vec![AnyPass::Upward(Box::new(AddPass))];
-    let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
-    dirty_nodes.insert(grandchild1, [PassId(0)].into_iter().collect());
-    dirty_nodes.insert(grandchild2, [PassId(0)].into_iter().collect());
+    let mut dirty_nodes: DirtyNodeStates = DirtyNodeStates::default();
+    dirty_nodes.insert(PassId(0), grandchild1);
+    dirty_nodes.insert(PassId(0), grandchild2);
     resolve_passes(&mut tree, dirty_nodes, passes);
 
     assert_eq!(tree.get(tree.root()).unwrap(), &2);