|
@@ -1,15 +1,14 @@
|
|
use crossbeam_deque::{Injector, Stealer, Worker};
|
|
use crossbeam_deque::{Injector, Stealer, Worker};
|
|
-use parking_lot::RwLock;
|
|
|
|
|
|
+use parking_lot::{Condvar, Mutex, RwLock};
|
|
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
|
|
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
|
|
use std::hash::BuildHasherDefault;
|
|
use std::hash::BuildHasherDefault;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
use std::sync::Arc;
|
|
use std::thread;
|
|
use std::thread;
|
|
-use std::time::Duration;
|
|
|
|
|
|
|
|
use crate::tree::{NodeId, SharedView, TreeView};
|
|
use crate::tree::{NodeId, SharedView, TreeView};
|
|
|
|
|
|
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
|
|
|
|
|
|
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
|
|
pub struct PassId(u64);
|
|
pub struct PassId(u64);
|
|
|
|
|
|
pub trait UpwardPass<T> {
|
|
pub trait UpwardPass<T> {
|
|
@@ -124,7 +123,7 @@ fn get_pass<T, Tr: TreeView<T>>(
|
|
dirty_nodes: &DirtyNodeStates,
|
|
dirty_nodes: &DirtyNodeStates,
|
|
shared_view: &mut SharedView<T, Tr>,
|
|
shared_view: &mut SharedView<T, Tr>,
|
|
global: &Injector<NodeId>,
|
|
global: &Injector<NodeId>,
|
|
- current_pass: &RwLock<Option<AnyPass<T>>>,
|
|
|
|
|
|
+ current_pass: &mut Option<AnyPass<T>>,
|
|
) {
|
|
) {
|
|
for i in 0..passes.len() {
|
|
for i in 0..passes.len() {
|
|
if passes[i]
|
|
if passes[i]
|
|
@@ -134,6 +133,7 @@ fn get_pass<T, Tr: TreeView<T>>(
|
|
{
|
|
{
|
|
let pass = passes.remove(i);
|
|
let pass = passes.remove(i);
|
|
let pass_id = pass.pass_id();
|
|
let pass_id = pass.pass_id();
|
|
|
|
+ println!("Running pass {:?}", pass_id);
|
|
resolved_passes.insert(pass_id);
|
|
resolved_passes.insert(pass_id);
|
|
match pass {
|
|
match pass {
|
|
AnyPass::Upward(pass) => {
|
|
AnyPass::Upward(pass) => {
|
|
@@ -164,13 +164,21 @@ fn get_pass<T, Tr: TreeView<T>>(
|
|
for node in dirty_nodes.all_dirty(pass_id) {
|
|
for node in dirty_nodes.all_dirty(pass_id) {
|
|
global.push(node);
|
|
global.push(node);
|
|
}
|
|
}
|
|
- current_pass.write().replace(pass);
|
|
|
|
|
|
+ println!(
|
|
|
|
+ "Task: {:?} {:?}",
|
|
|
|
+ pass_id,
|
|
|
|
+ dirty_nodes.all_dirty(pass_id).collect::<Vec<_>>()
|
|
|
|
+ );
|
|
|
|
+ current_pass.replace(pass);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
- break;
|
|
|
|
|
|
+ return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ panic!(
|
|
|
|
+ "No pass found with all dependancies resolved in {:?}",
|
|
|
|
+ passes.iter().map(|p| p.pass_id()).collect::<Vec<_>>()
|
|
|
|
+ );
|
|
}
|
|
}
|
|
|
|
|
|
pub fn resolve_passes<T>(
|
|
pub fn resolve_passes<T>(
|
|
@@ -178,6 +186,7 @@ pub fn resolve_passes<T>(
|
|
starting_nodes: FxHashMap<NodeId, FxHashSet<PassId>>,
|
|
starting_nodes: FxHashMap<NodeId, FxHashSet<PassId>>,
|
|
mut passes: Vec<AnyPass<T>>,
|
|
mut passes: Vec<AnyPass<T>>,
|
|
) {
|
|
) {
|
|
|
|
+ assert!(!passes.is_empty());
|
|
let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
|
|
let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
|
|
let global = Injector::default();
|
|
let global = Injector::default();
|
|
|
|
|
|
@@ -191,69 +200,105 @@ pub fn resolve_passes<T>(
|
|
let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
|
|
let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
|
|
|
|
|
|
thread::scope(|s| {
|
|
thread::scope(|s| {
|
|
- get_pass(
|
|
|
|
- &mut passes,
|
|
|
|
- &mut resolved_passes,
|
|
|
|
- &dirty_nodes,
|
|
|
|
- &mut shared_view,
|
|
|
|
- &global,
|
|
|
|
- ¤t_pass,
|
|
|
|
- );
|
|
|
|
|
|
+ {
|
|
|
|
+ let mut write = current_pass.write();
|
|
|
|
+ get_pass(
|
|
|
|
+ &mut passes,
|
|
|
|
+ &mut resolved_passes,
|
|
|
|
+ &dirty_nodes,
|
|
|
|
+ &mut shared_view,
|
|
|
|
+ &global,
|
|
|
|
+ &mut *write,
|
|
|
|
+ );
|
|
|
|
+ }
|
|
let global = &global;
|
|
let global = &global;
|
|
let stealers = &stealers;
|
|
let stealers = &stealers;
|
|
|
|
+ let mut thread_handles = Vec::new();
|
|
|
|
+ let threads_finished = Arc::new((Mutex::new(0), Condvar::new()));
|
|
for (_, w) in (0..core_count).zip(workers.into_iter()) {
|
|
for (_, w) in (0..core_count).zip(workers.into_iter()) {
|
|
let mut shared_view = shared_view.clone();
|
|
let mut shared_view = shared_view.clone();
|
|
let current_pass = current_pass.clone();
|
|
let current_pass = current_pass.clone();
|
|
let dirty_nodes = dirty_nodes.clone();
|
|
let dirty_nodes = dirty_nodes.clone();
|
|
- s.spawn(move || {
|
|
|
|
- while let Some(current_pass) = &*current_pass.read() {
|
|
|
|
- match current_pass {
|
|
|
|
- AnyPass::Upward(_) => {
|
|
|
|
- todo!("Upward passes are single threaded")
|
|
|
|
- }
|
|
|
|
- AnyPass::Node(pass) => {
|
|
|
|
- // Node passes are the easiest to parallelize. We just run the pass on each node.
|
|
|
|
- while let Some(id) = find_task(&w, global, stealers) {
|
|
|
|
- let node = shared_view.get_mut(id).unwrap();
|
|
|
|
- if pass.pass(node) {
|
|
|
|
- for dependant in pass.dependants() {
|
|
|
|
- dirty_nodes.insert(*dependant, id);
|
|
|
|
|
|
+ let threads_finished = threads_finished.clone();
|
|
|
|
+ thread_handles.push(s.spawn(move || {
|
|
|
|
+ loop {
|
|
|
|
+ let read = current_pass.read();
|
|
|
|
+ if let Some(current_pass) = &*read {
|
|
|
|
+ match current_pass {
|
|
|
|
+ AnyPass::Upward(_) => {
|
|
|
|
+ todo!("Upward passes are single threaded")
|
|
|
|
+ }
|
|
|
|
+ AnyPass::Node(pass) => {
|
|
|
|
+ // Node passes are the easiest to parallelize. We just run the pass on each node.
|
|
|
|
+ while let Some(id) = find_task(&w, global, stealers) {
|
|
|
|
+ let node = shared_view.get_mut(id).unwrap();
|
|
|
|
+ if pass.pass(node) {
|
|
|
|
+ for dependant in pass.dependants() {
|
|
|
|
+ dirty_nodes.insert(*dependant, id);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- }
|
|
|
|
- AnyPass::Downward(pass) => {
|
|
|
|
- // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load.
|
|
|
|
- while let Some(id) = find_task(&w, global, stealers) {
|
|
|
|
- let (node, parent) = shared_view.node_parent_mut(id).unwrap();
|
|
|
|
- if pass.pass(node, parent) {
|
|
|
|
- for id in shared_view.children_ids(id).unwrap() {
|
|
|
|
- for dependant in pass.dependants() {
|
|
|
|
- dirty_nodes.insert(*dependant, *id);
|
|
|
|
|
|
+ AnyPass::Downward(pass) => {
|
|
|
|
+ // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load.
|
|
|
|
+ while let Some(id) = find_task(&w, global, stealers) {
|
|
|
|
+ let (node, parent) = shared_view.node_parent_mut(id).unwrap();
|
|
|
|
+ if pass.pass(node, parent) {
|
|
|
|
+ for id in shared_view.children_ids(id).unwrap() {
|
|
|
|
+ for dependant in pass.dependants() {
|
|
|
|
+ dirty_nodes.insert(*dependant, *id);
|
|
|
|
+ }
|
|
|
|
+ w.push(*id);
|
|
}
|
|
}
|
|
- w.push(*id);
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+ // unblock the rwlock
|
|
|
|
+ drop(read);
|
|
|
|
+ {
|
|
|
|
+ let (count, cvar) = &*threads_finished;
|
|
|
|
+ let mut lock = count.lock();
|
|
|
|
+ *lock += 1;
|
|
|
|
+ if *lock == core_count {
|
|
|
|
+ cvar.notify_all();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // wait for the main thread to pick the next pass
|
|
|
|
+ thread::park();
|
|
|
|
+ } else {
|
|
|
|
+ break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
- });
|
|
|
|
|
|
+ }));
|
|
}
|
|
}
|
|
|
|
+ let wait_for_thread_to_finish = move || {
|
|
|
|
+ // this will block until all threads are done with the current pass
|
|
|
|
+ let (count, cvar) = &*threads_finished;
|
|
|
|
+ let mut count = count.lock();
|
|
|
|
+ cvar.wait(&mut count);
|
|
|
|
+ };
|
|
while !passes.is_empty() {
|
|
while !passes.is_empty() {
|
|
- while !stealers.iter().all(|s| s.is_empty()) {
|
|
|
|
- std::thread::sleep(Duration::from_millis(50));
|
|
|
|
- }
|
|
|
|
|
|
+ wait_for_thread_to_finish();
|
|
|
|
+ let mut write = current_pass.write();
|
|
|
|
+ println!("Threads finished");
|
|
get_pass(
|
|
get_pass(
|
|
&mut passes,
|
|
&mut passes,
|
|
&mut resolved_passes,
|
|
&mut resolved_passes,
|
|
&dirty_nodes,
|
|
&dirty_nodes,
|
|
&mut shared_view,
|
|
&mut shared_view,
|
|
global,
|
|
global,
|
|
- ¤t_pass,
|
|
|
|
|
|
+ &mut *write,
|
|
);
|
|
);
|
|
|
|
+ // notify all threads to start the next pass
|
|
|
|
+ for thread in &thread_handles {
|
|
|
|
+ thread.thread().unpark();
|
|
|
|
+ }
|
|
}
|
|
}
|
|
*current_pass.write() = None;
|
|
*current_pass.write() = None;
|
|
|
|
+ for thread in &thread_handles {
|
|
|
|
+ thread.thread().unpark();
|
|
|
|
+ }
|
|
});
|
|
});
|
|
}
|
|
}
|
|
|
|
|
|
@@ -318,6 +363,75 @@ fn node_pass() {
|
|
assert_eq!(tree.get(tree.root()).unwrap(), &1);
|
|
assert_eq!(tree.get(tree.root()).unwrap(), &1);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+#[test]
|
|
|
|
+fn dependant_node_pass() {
|
|
|
|
+ use crate::tree::{Tree, TreeLike};
|
|
|
|
+ let mut tree = Tree::new(0);
|
|
|
|
+ let parent = tree.root();
|
|
|
|
+ let child1 = tree.create_node(1);
|
|
|
|
+ tree.add_child(parent, child1);
|
|
|
|
+ let grandchild1 = tree.create_node(3);
|
|
|
|
+ tree.add_child(child1, grandchild1);
|
|
|
|
+ let child2 = tree.create_node(2);
|
|
|
|
+ tree.add_child(parent, child2);
|
|
|
|
+ let grandchild2 = tree.create_node(4);
|
|
|
|
+ tree.add_child(child2, grandchild2);
|
|
|
|
+
|
|
|
|
+ struct AddPass;
|
|
|
|
+
|
|
|
|
+ impl NodePass<i32> for AddPass {
|
|
|
|
+ fn pass_id(&self) -> PassId {
|
|
|
|
+ PassId(0)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependancies(&self) -> &'static [PassId] {
|
|
|
|
+ &[PassId(1)]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependants(&self) -> &'static [PassId] {
|
|
|
|
+ &[]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn pass(&self, node: &mut i32) -> bool {
|
|
|
|
+ println!("AddPass: {}", node);
|
|
|
|
+ *node += 1;
|
|
|
|
+ true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ struct SubtractPass;
|
|
|
|
+
|
|
|
|
+ impl NodePass<i32> for SubtractPass {
|
|
|
|
+ fn pass_id(&self) -> PassId {
|
|
|
|
+ PassId(1)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependancies(&self) -> &'static [PassId] {
|
|
|
|
+ &[]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependants(&self) -> &'static [PassId] {
|
|
|
|
+ &[PassId(0)]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn pass(&self, node: &mut i32) -> bool {
|
|
|
|
+ println!("SubtractPass: {}", node);
|
|
|
|
+ *node -= 1;
|
|
|
|
+ true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ let passes = vec![
|
|
|
|
+ AnyPass::Node(Box::new(AddPass)),
|
|
|
|
+ AnyPass::Node(Box::new(SubtractPass)),
|
|
|
|
+ ];
|
|
|
|
+ let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
|
|
|
|
+ dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
|
|
|
|
+ resolve_passes(&mut tree, dirty_nodes, passes);
|
|
|
|
+
|
|
|
|
+ assert_eq!(*tree.get(tree.root()).unwrap(), 0);
|
|
|
|
+}
|
|
|
|
+
|
|
#[test]
|
|
#[test]
|
|
fn down_pass() {
|
|
fn down_pass() {
|
|
use crate::tree::{Tree, TreeLike};
|
|
use crate::tree::{Tree, TreeLike};
|
|
@@ -367,6 +481,99 @@ fn down_pass() {
|
|
assert_eq!(tree.get(grandchild2).unwrap(), &3);
|
|
assert_eq!(tree.get(grandchild2).unwrap(), &3);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+#[test]
|
|
|
|
+fn dependant_down_pass() {
|
|
|
|
+ use crate::tree::{Tree, TreeLike};
|
|
|
|
+ let mut tree = Tree::new(1);
|
|
|
|
+ let parent = tree.root();
|
|
|
|
+ let child1 = tree.create_node(1);
|
|
|
|
+ tree.add_child(parent, child1);
|
|
|
|
+ let grandchild1 = tree.create_node(1);
|
|
|
|
+ tree.add_child(child1, grandchild1);
|
|
|
|
+ let child2 = tree.create_node(1);
|
|
|
|
+ tree.add_child(parent, child2);
|
|
|
|
+ let grandchild2 = tree.create_node(1);
|
|
|
|
+ tree.add_child(child2, grandchild2);
|
|
|
|
+
|
|
|
|
+ struct AddPass;
|
|
|
|
+
|
|
|
|
+ impl DownwardPass<i32> for AddPass {
|
|
|
|
+ fn pass_id(&self) -> PassId {
|
|
|
|
+ PassId(0)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependancies(&self) -> &'static [PassId] {
|
|
|
|
+ &[PassId(1)]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependants(&self) -> &'static [PassId] {
|
|
|
|
+ &[]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
|
|
|
|
+ if let Some(parent) = parent {
|
|
|
|
+ *node += *parent;
|
|
|
|
+ }
|
|
|
|
+ true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ struct SubtractPass;
|
|
|
|
+
|
|
|
|
+ impl DownwardPass<i32> for SubtractPass {
|
|
|
|
+ fn pass_id(&self) -> PassId {
|
|
|
|
+ PassId(1)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependancies(&self) -> &'static [PassId] {
|
|
|
|
+ &[]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn dependants(&self) -> &'static [PassId] {
|
|
|
|
+ &[PassId(0)]
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool {
|
|
|
|
+ if let Some(parent) = parent {
|
|
|
|
+ *node -= *parent;
|
|
|
|
+ }
|
|
|
|
+ true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ let passes = vec![
|
|
|
|
+ AnyPass::Downward(Box::new(AddPass)),
|
|
|
|
+ AnyPass::Downward(Box::new(SubtractPass)),
|
|
|
|
+ ];
|
|
|
|
+ let mut dirty_nodes: FxHashMap<NodeId, FxHashSet<PassId>> = FxHashMap::default();
|
|
|
|
+ dirty_nodes.insert(tree.root(), [PassId(1)].into_iter().collect());
|
|
|
|
+ resolve_passes(&mut tree, dirty_nodes, passes);
|
|
|
|
+
|
|
|
|
+ // Tree before:
|
|
|
|
+ // 1=\
|
|
|
|
+ // 1=\
|
|
|
|
+ // 1
|
|
|
|
+ // 1=\
|
|
|
|
+ // 1
|
|
|
|
+ // Tree after subtract:
|
|
|
|
+ // 1=\
|
|
|
|
+ // 0=\
|
|
|
|
+ // 1
|
|
|
|
+ // 0=\
|
|
|
|
+ // 1
|
|
|
|
+ // Tree after add:
|
|
|
|
+ // 1=\
|
|
|
|
+ // 1=\
|
|
|
|
+ // 2
|
|
|
|
+ // 1=\
|
|
|
|
+ // 2
|
|
|
|
+ assert_eq!(tree.get(tree.root()).unwrap(), &1);
|
|
|
|
+ assert_eq!(tree.get(child1).unwrap(), &1);
|
|
|
|
+ assert_eq!(tree.get(grandchild1).unwrap(), &2);
|
|
|
|
+ assert_eq!(tree.get(child2).unwrap(), &1);
|
|
|
|
+ assert_eq!(tree.get(grandchild2).unwrap(), &2);
|
|
|
|
+}
|
|
|
|
+
|
|
#[test]
|
|
#[test]
|
|
fn up_pass() {
|
|
fn up_pass() {
|
|
use crate::tree::{Tree, TreeLike};
|
|
use crate::tree::{Tree, TreeLike};
|