tree.rs 20 KB


  1. use std::collections::VecDeque;
  2. use std::marker::PhantomData;
  /// Opaque handle that identifies a node stored in a [`Tree`].
  ///
  /// The wrapped `usize` is the node's slot index in the tree's backing slab. Slots are
  /// recycled after removal, so an id kept past its node's removal may later name a
  /// different node.
  #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
  pub struct NodeId(pub usize);
  5. #[derive(PartialEq, Eq, Clone, Debug)]
  6. pub struct Node<T> {
  7. value: T,
  8. parent: Option<NodeId>,
  9. children: Vec<NodeId>,
  10. height: u16,
  11. }
  12. #[derive(Debug)]
  13. pub struct Tree<T> {
  14. nodes: Slab<Node<T>>,
  15. root: NodeId,
  16. }
  17. impl<T> Tree<T> {
  18. fn try_remove(&mut self, id: NodeId) -> Option<Node<T>> {
  19. self.nodes.try_remove(id.0).map(|node| {
  20. if let Some(parent) = node.parent {
  21. self.nodes
  22. .get_mut(parent.0)
  23. .unwrap()
  24. .children
  25. .retain(|child| child != &id);
  26. }
  27. for child in &node.children {
  28. self.remove_recursive(*child);
  29. }
  30. node
  31. })
  32. }
  33. fn remove_recursive(&mut self, node: NodeId) {
  34. let node = self.nodes.remove(node.0);
  35. for child in node.children {
  36. self.remove_recursive(child);
  37. }
  38. }
  39. fn set_height(&mut self, node: NodeId, height: u16) {
  40. let self_mut = self as *mut Self;
  41. let node = self.nodes.get_mut(node.0).unwrap();
  42. node.height = height;
  43. unsafe {
  44. // Safety: No node has itself as a child
  45. for child in &node.children {
  46. (*self_mut).set_height(*child, height + 1);
  47. }
  48. }
  49. }
  50. }
  51. pub trait TreeView<T>: Sized {
  52. type Iterator<'a>: Iterator<Item = &'a T>
  53. where
  54. T: 'a,
  55. Self: 'a;
  56. type IteratorMut<'a>: Iterator<Item = &'a mut T>
  57. where
  58. T: 'a,
  59. Self: 'a;
  60. fn root(&self) -> NodeId;
  61. fn contains(&self, id: NodeId) -> bool {
  62. self.get(id).is_some()
  63. }
  64. fn get(&self, id: NodeId) -> Option<&T>;
  65. fn get_unchecked(&self, id: NodeId) -> &T {
  66. unsafe { self.get(id).unwrap_unchecked() }
  67. }
  68. fn get_mut(&mut self, id: NodeId) -> Option<&mut T>;
  69. fn get_mut_unchecked(&mut self, id: NodeId) -> &mut T {
  70. unsafe { self.get_mut(id).unwrap_unchecked() }
  71. }
  72. fn children(&self, id: NodeId) -> Option<Self::Iterator<'_>>;
  73. fn children_mut(&mut self, id: NodeId) -> Option<Self::IteratorMut<'_>>;
  74. fn parent_child_mut(&mut self, id: NodeId) -> Option<(&mut T, Self::IteratorMut<'_>)>;
  75. fn children_ids(&self, id: NodeId) -> Option<&[NodeId]>;
  76. fn parent(&self, id: NodeId) -> Option<&T>;
  77. fn parent_mut(&mut self, id: NodeId) -> Option<&mut T>;
  78. fn node_parent_mut(&mut self, id: NodeId) -> Option<(&mut T, Option<&mut T>)>;
  79. fn parent_id(&self, id: NodeId) -> Option<NodeId>;
  80. fn height(&self, id: NodeId) -> Option<u16>;
  81. fn size(&self) -> usize;
  82. fn traverse_depth_first(&self, mut f: impl FnMut(&T)) {
  83. let mut stack = vec![self.root()];
  84. while let Some(id) = stack.pop() {
  85. if let Some(node) = self.get(id) {
  86. f(node);
  87. if let Some(children) = self.children_ids(id) {
  88. stack.extend(children.iter().copied().rev());
  89. }
  90. }
  91. }
  92. }
  93. fn traverse_depth_first_mut(&mut self, mut f: impl FnMut(&mut T)) {
  94. let mut stack = vec![self.root()];
  95. while let Some(id) = stack.pop() {
  96. if let Some(node) = self.get_mut(id) {
  97. f(node);
  98. if let Some(children) = self.children_ids(id) {
  99. stack.extend(children.iter().copied().rev());
  100. }
  101. }
  102. }
  103. }
  104. fn traverse_breadth_first(&self, mut f: impl FnMut(&T)) {
  105. let mut queue = VecDeque::new();
  106. queue.push_back(self.root());
  107. while let Some(id) = queue.pop_front() {
  108. if let Some(node) = self.get(id) {
  109. f(node);
  110. if let Some(children) = self.children_ids(id) {
  111. for id in children {
  112. queue.push_back(*id);
  113. }
  114. }
  115. }
  116. }
  117. }
  118. fn traverse_breadth_first_mut(&mut self, mut f: impl FnMut(&mut T)) {
  119. let mut queue = VecDeque::new();
  120. queue.push_back(self.root());
  121. while let Some(id) = queue.pop_front() {
  122. if let Some(node) = self.get_mut(id) {
  123. f(node);
  124. if let Some(children) = self.children_ids(id) {
  125. for id in children {
  126. queue.push_back(*id);
  127. }
  128. }
  129. }
  130. }
  131. }
  132. }
  133. pub trait TreeLike<T>: TreeView<T> {
  134. fn new(root: T) -> Self;
  135. fn create_node(&mut self, value: T) -> NodeId;
  136. fn add_child(&mut self, parent: NodeId, child: NodeId);
  137. fn remove(&mut self, id: NodeId) -> Option<T>;
  138. fn remove_all_children(&mut self, id: NodeId) -> Vec<T>;
  139. fn replace(&mut self, old: NodeId, new: NodeId);
  140. fn insert_before(&mut self, id: NodeId, new: NodeId);
  141. fn insert_after(&mut self, id: NodeId, new: NodeId);
  142. }
  143. pub struct ChildNodeIterator<'a, T> {
  144. nodes: &'a Slab<Node<T>>,
  145. children_ids: Vec<NodeId>,
  146. index: usize,
  147. node_type: PhantomData<T>,
  148. }
  149. impl<'a, T: 'a> Iterator for ChildNodeIterator<'a, T> {
  150. type Item = &'a T;
  151. fn next(&mut self) -> Option<Self::Item> {
  152. self.children_ids.get(self.index).map(|id| {
  153. self.index += 1;
  154. &self.nodes.get(id.0).unwrap().value
  155. })
  156. }
  157. }
  158. pub struct ChildNodeIteratorMut<'a, T> {
  159. nodes: Vec<&'a mut Node<T>>,
  160. }
  161. impl<'a, T: 'a> Iterator for ChildNodeIteratorMut<'a, T> {
  162. type Item = &'a mut T;
  163. fn next(&mut self) -> Option<Self::Item> {
  164. self.nodes.pop().map(|node| &mut node.value)
  165. }
  166. }
  167. impl<T> TreeView<T> for Tree<T> {
  168. type Iterator<'a> = ChildNodeIterator<'a, T> where T: 'a;
  169. type IteratorMut<'a> = ChildNodeIteratorMut<'a, T> where T: 'a;
  170. fn root(&self) -> NodeId {
  171. self.root
  172. }
  173. fn get(&self, id: NodeId) -> Option<&T> {
  174. self.nodes.get(id.0).map(|node| &node.value)
  175. }
  176. fn get_mut(&mut self, id: NodeId) -> Option<&mut T> {
  177. self.nodes.get_mut(id.0).map(|node| &mut node.value)
  178. }
  179. fn children(&self, id: NodeId) -> Option<Self::Iterator<'_>> {
  180. self.children_ids(id).map(|children_ids| ChildNodeIterator {
  181. nodes: &self.nodes,
  182. children_ids: children_ids.to_vec(),
  183. index: 0,
  184. node_type: PhantomData,
  185. })
  186. }
  187. fn children_mut(&mut self, id: NodeId) -> Option<Self::IteratorMut<'_>> {
  188. // Safety: No node has itself as a parent.
  189. if let Some(children_ids) = self.children_ids(id) {
  190. let children_ids = children_ids.to_vec();
  191. Some(ChildNodeIteratorMut {
  192. nodes: unsafe {
  193. self.nodes
  194. .get_many_mut_unchecked(children_ids.into_iter().rev().map(|id| id.0))
  195. .unwrap()
  196. },
  197. })
  198. } else {
  199. None
  200. }
  201. }
  202. fn children_ids(&self, id: NodeId) -> Option<&[NodeId]> {
  203. self.nodes.get(id.0).map(|node| node.children.as_slice())
  204. }
  205. fn parent(&self, id: NodeId) -> Option<&T> {
  206. self.nodes
  207. .get(id.0)
  208. .and_then(|node| node.parent.map(|id| self.nodes.get(id.0).unwrap()))
  209. .map(|node| &node.value)
  210. }
  211. fn parent_mut(&mut self, id: NodeId) -> Option<&mut T> {
  212. let self_ptr = self as *mut Self;
  213. unsafe {
  214. // Safety: No node has itself as a parent.
  215. self.nodes
  216. .get_mut(id.0)
  217. .and_then(move |node| {
  218. node.parent
  219. .map(move |id| (*self_ptr).nodes.get_mut(id.0).unwrap())
  220. })
  221. .map(|node| &mut node.value)
  222. }
  223. }
  224. fn parent_id(&self, id: NodeId) -> Option<NodeId> {
  225. self.nodes.get(id.0).and_then(|node| node.parent)
  226. }
  227. fn height(&self, id: NodeId) -> Option<u16> {
  228. self.nodes.get(id.0).map(|n| n.height)
  229. }
  230. fn get_unchecked(&self, id: NodeId) -> &T {
  231. unsafe { &self.nodes.get_unchecked(id.0).value }
  232. }
  233. fn get_mut_unchecked(&mut self, id: NodeId) -> &mut T {
  234. unsafe { &mut self.nodes.get_unchecked_mut(id.0).value }
  235. }
  236. fn size(&self) -> usize {
  237. self.nodes.len()
  238. }
  239. fn node_parent_mut(&mut self, id: NodeId) -> Option<(&mut T, Option<&mut T>)> {
  240. if let Some(parent_id) = self.parent_id(id) {
  241. self.nodes
  242. .get2_mut(id.0, parent_id.0)
  243. .map(|(node, parent)| (&mut node.value, Some(&mut parent.value)))
  244. } else {
  245. self.nodes.get_mut(id.0).map(|node| (&mut node.value, None))
  246. }
  247. }
  248. fn parent_child_mut(&mut self, id: NodeId) -> Option<(&mut T, Self::IteratorMut<'_>)> {
  249. // Safety: No node will appear as a child twice
  250. if let Some(children_ids) = self.children_ids(id) {
  251. debug_assert!(!children_ids.iter().any(|child_id| *child_id == id));
  252. let mut borrowed = unsafe {
  253. let as_vec = children_ids.to_vec();
  254. self.nodes
  255. .get_many_mut_unchecked(
  256. as_vec
  257. .into_iter()
  258. .rev()
  259. .map(|id| id.0)
  260. .chain(std::iter::once(id.0)),
  261. )
  262. .unwrap()
  263. };
  264. let node = &mut borrowed.pop().unwrap().value;
  265. Some((node, ChildNodeIteratorMut { nodes: borrowed }))
  266. } else {
  267. None
  268. }
  269. }
  270. }
  271. impl<T> TreeLike<T> for Tree<T> {
  272. fn new(root: T) -> Self {
  273. let mut nodes = Slab::default();
  274. let root = NodeId(nodes.insert(Node {
  275. value: root,
  276. parent: None,
  277. children: Vec::new(),
  278. height: 0,
  279. }));
  280. Self { nodes, root }
  281. }
  282. fn create_node(&mut self, value: T) -> NodeId {
  283. NodeId(self.nodes.insert(Node {
  284. value,
  285. parent: None,
  286. children: Vec::new(),
  287. height: 0,
  288. }))
  289. }
  290. fn add_child(&mut self, parent: NodeId, new: NodeId) {
  291. self.nodes.get_mut(new.0).unwrap().parent = Some(parent);
  292. let parent = self.nodes.get_mut(parent.0).unwrap();
  293. parent.children.push(new);
  294. let height = parent.height + 1;
  295. self.set_height(new, height);
  296. }
  297. fn remove(&mut self, id: NodeId) -> Option<T> {
  298. self.try_remove(id).map(|node| node.value)
  299. }
  300. fn remove_all_children(&mut self, id: NodeId) -> Vec<T> {
  301. let mut children = Vec::new();
  302. let self_mut = self as *mut Self;
  303. for child in self.children_ids(id).unwrap() {
  304. unsafe {
  305. // Safety: No node has itself as a child
  306. children.push((*self_mut).remove(*child).unwrap());
  307. }
  308. }
  309. children
  310. }
  311. fn replace(&mut self, old_id: NodeId, new_id: NodeId) {
  312. // remove the old node
  313. let old = self
  314. .try_remove(old_id)
  315. .expect("tried to replace a node that doesn't exist");
  316. // update the parent's link to the child
  317. if let Some(parent_id) = old.parent {
  318. let parent = self.nodes.get_mut(parent_id.0).unwrap();
  319. for id in &mut parent.children {
  320. if *id == old_id {
  321. *id = new_id;
  322. }
  323. }
  324. let height = parent.height + 1;
  325. self.set_height(new_id, height);
  326. }
  327. }
  328. fn insert_before(&mut self, id: NodeId, new: NodeId) {
  329. let node = self.nodes.get(id.0).unwrap();
  330. let parent_id = node.parent.expect("tried to insert before root");
  331. self.nodes.get_mut(new.0).unwrap().parent = Some(parent_id);
  332. let parent = self.nodes.get_mut(parent_id.0).unwrap();
  333. let index = parent
  334. .children
  335. .iter()
  336. .position(|child| child == &id)
  337. .unwrap();
  338. parent.children.insert(index, new);
  339. let height = parent.height + 1;
  340. self.set_height(new, height);
  341. }
  342. fn insert_after(&mut self, id: NodeId, new: NodeId) {
  343. let node = self.nodes.get(id.0).unwrap();
  344. let parent_id = node.parent.expect("tried to insert before root");
  345. self.nodes.get_mut(new.0).unwrap().parent = Some(parent_id);
  346. let parent = self.nodes.get_mut(parent_id.0).unwrap();
  347. let index = parent
  348. .children
  349. .iter()
  350. .position(|child| child == &id)
  351. .unwrap();
  352. parent.children.insert(index + 1, new);
  353. let height = parent.height + 1;
  354. self.set_height(new, height);
  355. }
  356. }
  /// A root with a single attached child reports the expected structure, heights and values.
  #[test]
  fn creation() {
      let mut tree = Tree::new(1);
      let root = tree.root();
      let leaf = tree.create_node(0);
      tree.add_child(root, leaf);
      println!("Tree: {:#?}", tree);

      // Shape of the tree.
      assert_eq!(tree.size(), 2);
      assert_eq!(tree.parent_id(root), None);
      assert_eq!(tree.parent_id(leaf).unwrap(), root);
      assert_eq!(tree.children_ids(root).unwrap(), &[leaf]);

      // Per-node (id, height, value) expectations.
      for (id, height, value) in [(root, 0, 1), (leaf, 1, 0)] {
          assert_eq!(tree.height(id), Some(height));
          assert_eq!(*tree.get(id).unwrap(), value);
      }
  }
  /// `insert_before` / `insert_after` place siblings around an existing child in order.
  #[test]
  fn insertion() {
      let mut tree = Tree::new(0);
      let root = tree.root();
      let middle = tree.create_node(2);
      tree.add_child(root, middle);
      let first = tree.create_node(1);
      tree.insert_before(middle, first);
      let last = tree.create_node(3);
      tree.insert_after(middle, last);
      println!("Tree: {:#?}", tree);

      assert_eq!(tree.size(), 4);
      assert_eq!(tree.height(root), Some(0));
      assert_eq!(*tree.get(root).unwrap(), 0);
      // Every sibling sits one level under the root and keeps its own value.
      for (id, value) in [(first, 1), (middle, 2), (last, 3)] {
          assert_eq!(tree.height(id), Some(1));
          assert_eq!(*tree.get(id).unwrap(), value);
          assert_eq!(tree.parent_id(id).unwrap(), root);
      }
      assert_eq!(tree.children_ids(root).unwrap(), &[first, middle, last]);
  }
  /// Removing siblings one by one keeps the remaining structure consistent.
  #[test]
  fn deletion() {
      // Prints the tree, then checks that exactly `live` (id, value) pairs hang off the root
      // in order, one level deep, and that each id in `gone` no longer resolves.
      fn check(tree: &Tree<i32>, root: NodeId, live: &[(NodeId, i32)], gone: &[NodeId]) {
          println!("Tree: {:#?}", tree);
          assert_eq!(tree.size(), live.len() + 1);
          assert_eq!(tree.height(root), Some(0));
          assert_eq!(*tree.get(root).unwrap(), 0);
          for &(id, value) in live {
              assert_eq!(tree.height(id), Some(1));
              assert_eq!(*tree.get(id).unwrap(), value);
              assert_eq!(tree.parent_id(id).unwrap(), root);
          }
          for &id in gone {
              assert_eq!(tree.get(id), None);
          }
          let expected: Vec<NodeId> = live.iter().map(|&(id, _)| id).collect();
          assert_eq!(tree.children_ids(root).unwrap(), expected.as_slice());
      }

      let mut tree = Tree::new(0);
      let root = tree.root();
      let middle = tree.create_node(2);
      tree.add_child(root, middle);
      let first = tree.create_node(1);
      tree.insert_before(middle, first);
      let last = tree.create_node(3);
      tree.insert_after(middle, last);
      check(&tree, root, &[(first, 1), (middle, 2), (last, 3)], &[]);

      tree.remove(middle);
      check(&tree, root, &[(first, 1), (last, 3)], &[middle]);

      tree.remove(first);
      check(&tree, root, &[(last, 3)], &[first]);

      tree.remove(last);
      check(&tree, root, &[], &[last]);
  }
  /// A depth-first walk visits every node in pre-order.
  #[test]
  fn traverse_depth_first() {
      let mut tree = Tree::new(0);
      let parent = tree.root();
      let child1 = tree.create_node(1);
      tree.add_child(parent, child1);
      let grandchild1 = tree.create_node(2);
      tree.add_child(child1, grandchild1);
      let child2 = tree.create_node(3);
      tree.add_child(parent, child2);
      let grandchild2 = tree.create_node(4);
      tree.add_child(child2, grandchild2);

      // Values were assigned in pre-order, so the walk must see exactly 0..=4 in order.
      // Record the full sequence: a counter captured by a `move` closure is a private copy,
      // so it could not reveal a traversal that stopped early or skipped nodes.
      let mut visited = Vec::new();
      tree.traverse_depth_first(|node| visited.push(*node));
      assert_eq!(visited, vec![0, 1, 2, 3, 4]);
  }
  /// `parent_child_mut` yields the node plus all of its children in order, and
  /// `children_mut` writes through to the stored values.
  #[test]
  fn get_node_children_mut() {
      let mut tree = Tree::new(0);
      let parent = tree.root();
      let child1 = tree.create_node(1);
      tree.add_child(parent, child1);
      let child2 = tree.create_node(2);
      tree.add_child(parent, child2);
      let child3 = tree.create_node(3);
      tree.add_child(parent, child3);

      let (parent_value, children) = tree.parent_child_mut(parent).unwrap();
      // Collect before comparing so a short or empty iterator fails the test.
      let seen: Vec<usize> = children.map(|child| *child).collect();
      assert_eq!(seen, vec![1, 2, 3]);
      println!("Parent: {:#?}", parent_value);
      assert_eq!(*parent_value, 0);

      for child in tree.children_mut(parent).unwrap() {
          *child *= 10;
      }
      let updated: Vec<usize> = tree.children(parent).unwrap().copied().collect();
      assert_eq!(updated, vec![10, 20, 30]);
  }
  /// `Slab::get_many_mut_unchecked` returns live references to each requested slot, in order.
  #[test]
  fn get_many_mut_unchecked() {
      let mut slab = Slab::new();
      let parent = slab.insert(0);
      let child = slab.insert(1);
      let grandchild = slab.insert(2);
      // SAFETY: the three ids are distinct, so the returned `&mut`s do not alias.
      let all = unsafe {
          slab.get_many_mut_unchecked([parent, child, grandchild].iter().copied())
      }
      .unwrap();
      println!("All: {:#?}", all);
      assert_eq!(all.iter().map(|value| **value).collect::<Vec<i32>>(), vec![0, 1, 2]);

      // The references must point at the slab's own storage.
      for value in all {
          *value += 10;
      }
      assert_eq!(slab.get(child), Some(&11));
  }
  /// Minimal slab allocator: values live in `Vec` slots addressed by index, and freed slots
  /// are reused (oldest first) by later inserts.
  #[derive(Debug)]
  struct Slab<T> {
      /// Slot storage; `None` marks a vacant slot.
      data: Vec<Option<T>>,
      /// Indices of vacant slots in the order they were freed; each appears at most once.
      free: VecDeque<usize>,
  }

  impl<T> Default for Slab<T> {
      fn default() -> Self {
          Self::new()
      }
  }

  impl<T> Slab<T> {
      /// Creates an empty slab.
      fn new() -> Self {
          Self {
              data: Vec::new(),
              free: VecDeque::new(),
          }
      }

      /// Returns the value at `id`, or `None` if the slot is vacant or out of range.
      fn get(&self, id: usize) -> Option<&T> {
          self.data.get(id).and_then(Option::as_ref)
      }

      /// Returns the value at `id` without a bounds check.
      ///
      /// # Safety
      /// `id` must be less than the number of slots. Panics if the slot is vacant.
      unsafe fn get_unchecked(&self, id: usize) -> &T {
          // SAFETY: the caller guarantees `id` is in bounds.
          unsafe { self.data.get_unchecked(id) }.as_ref().unwrap()
      }

      /// Returns the value at `id` mutably, or `None` if the slot is vacant or out of range.
      fn get_mut(&mut self, id: usize) -> Option<&mut T> {
          self.data.get_mut(id).and_then(Option::as_mut)
      }

      /// Returns the value at `id` mutably without a bounds check.
      ///
      /// # Safety
      /// `id` must be less than the number of slots. Panics if the slot is vacant.
      unsafe fn get_unchecked_mut(&mut self, id: usize) -> &mut T {
          // SAFETY: the caller guarantees `id` is in bounds.
          unsafe { self.data.get_unchecked_mut(id) }.as_mut().unwrap()
      }

      /// Returns mutable references to two different slots at once, in argument order.
      ///
      /// Returns `None` if either index is out of range or vacant.
      ///
      /// # Panics
      /// Panics if `id1 == id2`.
      fn get2_mut(&mut self, id1: usize, id2: usize) -> Option<(&mut T, &mut T)> {
          assert!(id1 != id2);
          let (low, high) = if id1 < id2 { (id1, id2) } else { (id2, id1) };
          // The old raw-pointer version dereferenced `ptr.add(id)` without this check.
          if high >= self.data.len() {
              return None;
          }
          // Disjoint sub-slices give two non-aliasing borrows without any `unsafe`.
          let (head, tail) = self.data.split_at_mut(high);
          let low_value = head[low].as_mut()?;
          let high_value = tail[0].as_mut()?;
          Some(if id1 < id2 {
              (low_value, high_value)
          } else {
              (high_value, low_value)
          })
      }

      /// Returns mutable references to the slots named by `ids`, in iteration order.
      ///
      /// Returns `None` if any id is out of range or vacant.
      ///
      /// # Safety
      /// Every id yielded by `ids` must be distinct; otherwise the returned `&mut`s alias.
      unsafe fn get_many_mut_unchecked(
          &mut self,
          ids: impl Iterator<Item = usize>,
      ) -> Option<Vec<&mut T>> {
          let len = self.data.len();
          let base = self.data.as_mut_ptr();
          let mut result = Vec::new();
          for id in ids {
              if id >= len {
                  return None;
              }
              // SAFETY: `id < len`, so the pointer is in bounds; the caller guarantees the
              // ids are distinct, so no two returned references overlap.
              let slot = unsafe { &mut *base.add(id) };
              result.push(slot.as_mut()?);
          }
          Some(result)
      }

      /// Stores `value`, reusing the oldest freed slot if there is one, and returns its index.
      fn insert(&mut self, value: T) -> usize {
          if let Some(id) = self.free.pop_front() {
              self.data[id] = Some(value);
              id
          } else {
              self.data.push(Some(value));
              self.data.len() - 1
          }
      }

      /// Removes and returns the value at `id`, or `None` if the slot is vacant or out of range.
      fn try_remove(&mut self, id: usize) -> Option<T> {
          let value = self.data.get_mut(id)?.take()?;
          // Record the slot as free only when it actually went from occupied to vacant.
          // Pushing on every call let a double remove hand one slot to two later inserts
          // and made `len` wrong (or underflow).
          self.free.push_back(id);
          Some(value)
      }

      /// Removes and returns the value at `id`.
      ///
      /// # Panics
      /// Panics if the slot is vacant or out of range.
      fn remove(&mut self, id: usize) -> T {
          self.try_remove(id).unwrap()
      }

      /// Number of occupied slots.
      fn len(&self) -> usize {
          self.data.len() - self.free.len()
      }
  }