ソースを参照

add more code for parrellel passes

Evan Almloff 2 年 前
コミット
c7eeeef68e
2 ファイル変更213 行追加26 行削除
  1. 1 0
      packages/native-core/Cargo.toml
  2. 212 26
      packages/native-core/src/tree.rs

+ 1 - 0
packages/native-core/Cargo.toml

@@ -22,6 +22,7 @@ anymap = "0.12.1"
 slab = "0.4"
 parking_lot = "0.12.1"
 crossbeam-deque = "0.8.2"
+dashmap = "5.4.0"
 
 [dev-dependencies]
 rand = "0.8.5"

+ 212 - 26
packages/native-core/src/tree.rs

@@ -1,7 +1,14 @@
 use core::panic;
+use std::hash::BuildHasherDefault;
+use std::sync::atomic::{AtomicU64, Ordering};
+
 use crossbeam_deque::{Injector, Stealer, Worker};
+use dashmap::DashSet;
+use dioxus_core::ScopeId;
+use dioxus_html::u;
 use parking_lot::lock_api::RawMutex as _;
-use parking_lot::{RawMutex, RwLock};
+use parking_lot::{Mutex, RawMutex, RwLock, RwLockWriteGuard};
+use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
 use slab::Slab;
 use std::cell::UnsafeCell;
 use std::collections::VecDeque;
@@ -822,39 +829,218 @@ fn traverse_breadth_first() {
     });
 }
 
+enum PassDirection {
+    Up,
+    Down,
+    Node,
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+struct PassId(u64);
+
 trait UpwardPass<T> {
-    fn upward_pass(&mut self, node: &mut T, parent: Option<&mut T>) -> bool;
+    fn pass_id(&self) -> PassId;
+    fn dependancies(&self) -> &'static [PassId];
+    fn dependants(&self) -> &'static [PassId];
+    fn upward_pass(&self, node: &mut T, parent: Option<&mut T>) -> bool;
+}
+
+trait DownwardPass<T> {
+    fn pass_id(&self) -> PassId;
+    fn dependancies(&self) -> &'static [PassId];
+    fn dependants(&self) -> &'static [PassId];
+    fn downward_pass<'a>(
+        &self,
+        node: &mut T,
+        children: &mut dyn Iterator<Item = &'a mut T>,
+    ) -> bool;
+}
+
+trait NodePass<T> {
+    fn pass_id(&self) -> PassId;
+    fn dependancies(&self) -> &'static [PassId];
+    fn dependants(&self) -> &'static [PassId];
+    fn node_pass(&self, node: &mut T) -> bool;
+}
+
+enum AnyPass<T> {
+    Upward(Box<dyn UpwardPass<T> + Send + Sync>),
+    Downward(Box<dyn DownwardPass<T> + Send + Sync>),
+    Node(Box<dyn NodePass<T> + Send + Sync>),
+}
+
+impl<T> AnyPass<T> {
+    fn pass_id(&self) -> PassId {
+        match self {
+            Self::Upward(pass) => pass.pass_id(),
+            Self::Downward(pass) => pass.pass_id(),
+            Self::Node(pass) => pass.pass_id(),
+        }
+    }
+
+    fn dependancies(&self) -> &'static [PassId] {
+        match self {
+            Self::Upward(pass) => pass.dependancies(),
+            Self::Downward(pass) => pass.dependancies(),
+            Self::Node(pass) => pass.dependancies(),
+        }
+    }
 
-    fn resolve_pass(&mut self, tree: &mut impl TreeView<T>, starting_nodes: &[NodeId]) {
-        let global = Injector::default();
-        for node in starting_nodes {
-            global.push(*node);
+    fn dependants(&self) -> &'static [PassId] {
+        match self {
+            Self::Upward(pass) => pass.dependants(),
+            Self::Downward(pass) => pass.dependants(),
+            Self::Node(pass) => pass.dependants(),
         }
+    }
+}
+
+type FxDashSet<T> = dashmap::DashSet<T, BuildHasherDefault<FxHasher>>;
+type FxDashMap<K, V> = dashmap::DashMap<K, V, BuildHasherDefault<FxHasher>>;
+
+#[derive(Default)]
+struct DirtyNodeStates {
+    dirty: FxDashMap<NodeId, Vec<AtomicU64>>,
+}
 
-        let core_count = thread::available_parallelism()
-            .map(|c| c.get())
-            .unwrap_or(1);
-        let workers: Vec<_> = (0..core_count).map(|_| Worker::new_fifo()).collect();
-        let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
-        let shared_view = SharedView::new(tree);
-        thread::scope(|s| {
-            let global = &global;
-            let stealers = &stealers;
-            for (_, w) in (0..core_count).zip(workers.into_iter()) {
-                let mut shared_view = shared_view.clone();
-                s.spawn(move || {
-                    while let Some(id) = find_task(&w, &global, &stealers) {
-                        let (node, parent) = shared_view.node_parent_mut(id).unwrap();
-                        if self.upward_pass(node, parent) {
-                            if let Some(id) = shared_view.parent_id(id) {
-                                w.push(id);
+impl DirtyNodeStates {
+    fn new(starting_nodes: FxHashMap<NodeId, FxDashSet<PassId>>) -> Self {
+        let mut this = Self::default();
+        for (node, nodes) in starting_nodes {
+            for pass_id in nodes {
+                this.insert(pass_id, node);
+            }
+        }
+        this
+    }
+
+    fn insert(&self, pass_id: PassId, node_id: NodeId) {
+        let pass_id = pass_id.0;
+        let index = pass_id / 64;
+        let bit = pass_id % 64;
+        let encoded = 1 << bit;
+        if let Some(dirty) = self.dirty.get(&node_id) {
+            if let Some(atomic) = dirty.get(index as usize) {
+                atomic.fetch_or(encoded, Ordering::Relaxed);
+            } else {
+                drop(dirty);
+                let mut write = self.dirty.get_mut(&node_id).unwrap();
+                write.resize_with(index as usize + 1, || AtomicU64::new(0));
+                write[index as usize].fetch_or(encoded, Ordering::Relaxed);
+            }
+        } else {
+            self.dirty.insert(node_id, vec![AtomicU64::new(encoded)]);
+        }
+    }
+
+    fn all_dirty(&self, pass_id: PassId) -> impl Iterator<Item = NodeId> + '_ {
+        let pass_id = pass_id.0;
+        let index = pass_id / 64;
+        let bit = pass_id % 64;
+        let encoded = 1 << bit;
+        self.dirty.iter().filter_map(move |entry| {
+            let node_id = entry.key();
+            let dirty = entry.value();
+            if let Some(atomic) = dirty.get(index as usize) {
+                if atomic.load(Ordering::Relaxed) & encoded != 0 {
+                    Some(*node_id)
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        })
+    }
+}
+
+fn resolve_passes<T>(
+    tree: &mut impl TreeView<T>,
+    starting_nodes: FxHashMap<NodeId, FxDashSet<PassId>>,
+    mut passes: Vec<AnyPass<T>>,
+) {
+    let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
+    let global = Injector::default();
+
+    let core_count = thread::available_parallelism()
+        .map(|c| c.get())
+        .unwrap_or(1);
+    let workers: Vec<Worker<NodeId>> = (0..core_count).map(|_| Worker::new_fifo()).collect();
+    let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
+    let shared_view = SharedView::new(tree);
+    let mut resolved_passes: FxHashSet<PassId> = FxHashSet::default();
+    let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
+    thread::scope(|s| {
+        let global = &global;
+        let stealers = &stealers;
+        for (_, w) in (0..core_count).zip(workers.into_iter()) {
+            let mut shared_view = shared_view.clone();
+            let current_pass = current_pass.clone();
+            let dirty_nodes = dirty_nodes.clone();
+            s.spawn(move || {
+                while let Some(current_pass) = &*current_pass.read() {
+                    match current_pass {
+                        AnyPass::Upward(pass) => {
+                            while let Some(id) = find_task(&w, global, stealers) {
+                                let (node, parent) = shared_view.node_parent_mut(id).unwrap();
+                                if pass.upward_pass(node, parent) {
+                                    if let Some(id) = shared_view.parent_id(id) {
+                                        for dependant in pass.dependants() {
+                                            dirty_nodes.insert(*dependant, id);
+                                        }
+                                        w.push(id);
+                                    }
+                                }
                             }
                         }
+                        AnyPass::Downward(pass) => {
+                            while let Some(id) = find_task(&w, global, stealers) {
+                                let (node, mut children) =
+                                    shared_view.parent_child_mut(id).unwrap();
+                                if pass.downward_pass(node, &mut children) {
+                                    drop(children);
+                                    for id in shared_view.children_ids(id).unwrap() {
+                                        for dependant in pass.dependants() {
+                                            dirty_nodes.insert(*dependant, *id);
+                                        }
+                                        w.push(*id);
+                                    }
+                                }
+                            }
+                        }
+                        AnyPass::Node(pass) => {
+                            while let Some(id) = find_task(&w, global, stealers) {
+                                let node = shared_view.get_mut(id).unwrap();
+                                if pass.node_pass(node) {
+                                    for dependant in pass.dependants() {
+                                        dirty_nodes.insert(*dependant, id);
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            });
+        }
+        while !passes.is_empty() {
+            for i in 0..passes.len() {
+                if passes[i]
+                    .dependancies()
+                    .iter()
+                    .all(|id| resolved_passes.contains(id))
+                {
+                    let pass = passes.remove(i);
+                    let pass_id = pass.pass_id();
+                    for node in dirty_nodes.all_dirty(pass_id) {
+                        global.push(node);
                     }
-                });
+                    resolved_passes.insert(pass_id);
+                    break;
+                }
             }
-        });
-    }
+        }
+        *current_pass.write() = None;
+    });
 }
 
 fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {