tree.rs 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  //! A tree of nodes integrated with shipyard.
  2. use crate::NodeId;
  3. use shipyard::{Component, EntitiesViewMut, Get, View, ViewMut};
  4. use std::fmt::Debug;
  /// A node in a tree.
  ///
  /// Stored as a shipyard component on the node's own entity; other nodes
  /// are referenced by their `NodeId`.
  #[derive(PartialEq, Eq, Clone, Debug, Component)]
  pub struct Node {
      /// The node's parent, or `None` if it has not been attached to one.
      parent: Option<NodeId>,
      /// The node's children, in sibling order.
      children: Vec<NodeId>,
      /// Distance from the root: 0 for a freshly created node, otherwise one
      /// more than the parent's height once the node is attached.
      height: u16,
  }
  /// A read-only view of a tree: shared access to the `Node` storage.
  pub type TreeRefView<'a> = View<'a, Node>;
  /// A mutable view of a tree: the entities view (used to attach new `Node`
  /// components to existing entities) plus exclusive access to the `Node`
  /// storage.
  pub type TreeMutView<'a> = (EntitiesViewMut<'a>, ViewMut<'a, Node>);
  /// An immutable view of a tree.
  pub trait TreeRef {
      /// The parent id of the node, or `None` if it has no parent.
      fn parent_id(&self, id: NodeId) -> Option<NodeId>;
      /// The ids of the node's children, in sibling order (empty if the node
      /// has no children or does not exist).
      fn children_ids(&self, id: NodeId) -> Vec<NodeId>;
      /// The height (distance from the root) of the node, or `None` if the
      /// node does not exist.
      fn height(&self, id: NodeId) -> Option<u16>;
      /// Returns true if the node exists.
      fn contains(&self, id: NodeId) -> bool;
  }
  /// A mutable view of a tree.
  pub trait TreeMut: TreeRef {
      /// Removes the node and all of its children.
      ///
      /// NOTE(review): the `TreeMutView` implementation only detaches the node
      /// from its parent's child list and deletes no `Node` components —
      /// confirm which behavior is intended.
      fn remove(&mut self, id: NodeId);
      /// Removes only this node by detaching it from its parent's child list;
      /// its children are not visited.
      fn remove_single(&mut self, id: NodeId);
      /// Adds a new node, with no parent and no children, to the tree.
      fn create_node(&mut self, id: NodeId);
      /// Appends `new` as the last child of `parent`.
      fn add_child(&mut self, parent: NodeId, new: NodeId);
      /// Replaces the node `old_id` with `new_id` in its parent's child list
      /// and then removes `old_id`.
      fn replace(&mut self, old_id: NodeId, new_id: NodeId);
      /// Inserts `new_id` as a sibling immediately before `old_id`.
      ///
      /// The `TreeMutView` implementation panics if `old_id` has no parent.
      fn insert_before(&mut self, old_id: NodeId, new_id: NodeId);
      /// Inserts `new_id` as a sibling immediately after `old_id`.
      ///
      /// The `TreeMutView` implementation panics if `old_id` has no parent.
      fn insert_after(&mut self, old_id: NodeId, new_id: NodeId);
  }
  44. impl<'a> TreeRef for TreeRefView<'a> {
  45. fn parent_id(&self, id: NodeId) -> Option<NodeId> {
  46. self.get(id).ok()?.parent
  47. }
  48. fn children_ids(&self, id: NodeId) -> Vec<NodeId> {
  49. self.get(id)
  50. .map(|node| node.children.clone())
  51. .unwrap_or_default()
  52. }
  53. fn height(&self, id: NodeId) -> Option<u16> {
  54. Some(self.get(id).ok()?.height)
  55. }
  56. fn contains(&self, id: NodeId) -> bool {
  57. self.get(id).is_ok()
  58. }
  59. }
  60. impl<'a> TreeMut for TreeMutView<'a> {
  61. fn remove(&mut self, id: NodeId) {
  62. fn recurse(tree: &mut TreeMutView<'_>, id: NodeId) {
  63. let children = tree.children_ids(id);
  64. for child in children {
  65. recurse(tree, child);
  66. }
  67. }
  68. {
  69. let mut node_data_mut = &mut self.1;
  70. if let Some(parent) = node_data_mut.get(id).unwrap().parent {
  71. let parent = (&mut node_data_mut).get(parent).unwrap();
  72. parent.children.retain(|&child| child != id);
  73. }
  74. }
  75. recurse(self, id);
  76. }
  77. fn remove_single(&mut self, id: NodeId) {
  78. {
  79. let mut node_data_mut = &mut self.1;
  80. if let Some(parent) = node_data_mut.get(id).unwrap().parent {
  81. let parent = (&mut node_data_mut).get(parent).unwrap();
  82. parent.children.retain(|&child| child != id);
  83. }
  84. }
  85. }
  86. fn create_node(&mut self, id: NodeId) {
  87. let (entities, node_data_mut) = self;
  88. entities.add_component(
  89. id,
  90. node_data_mut,
  91. Node {
  92. parent: None,
  93. children: Vec::new(),
  94. height: 0,
  95. },
  96. );
  97. }
  98. fn add_child(&mut self, parent: NodeId, new: NodeId) {
  99. let height;
  100. {
  101. let mut node_state = &mut self.1;
  102. (&mut node_state).get(new).unwrap().parent = Some(parent);
  103. let parent = (&mut node_state).get(parent).unwrap();
  104. parent.children.push(new);
  105. height = parent.height + 1;
  106. }
  107. set_height(self, new, height);
  108. }
  109. fn replace(&mut self, old_id: NodeId, new_id: NodeId) {
  110. {
  111. let mut node_state = &mut self.1;
  112. // update the parent's link to the child
  113. if let Some(parent_id) = node_state.get(old_id).unwrap().parent {
  114. let parent = (&mut node_state).get(parent_id).unwrap();
  115. for id in &mut parent.children {
  116. if *id == old_id {
  117. *id = new_id;
  118. break;
  119. }
  120. }
  121. let height = parent.height + 1;
  122. set_height(self, new_id, height);
  123. }
  124. }
  125. // remove the old node
  126. self.remove(old_id);
  127. }
  128. fn insert_before(&mut self, old_id: NodeId, new_id: NodeId) {
  129. let mut node_state = &mut self.1;
  130. let old_node = node_state.get(old_id).unwrap();
  131. let parent_id = old_node.parent.expect("tried to insert before root");
  132. (&mut node_state).get(new_id).unwrap().parent = Some(parent_id);
  133. let parent = (&mut node_state).get(parent_id).unwrap();
  134. let index = parent
  135. .children
  136. .iter()
  137. .position(|child| *child == old_id)
  138. .unwrap();
  139. parent.children.insert(index, new_id);
  140. let height = parent.height + 1;
  141. set_height(self, new_id, height);
  142. }
  143. fn insert_after(&mut self, old_id: NodeId, new_id: NodeId) {
  144. let mut node_state = &mut self.1;
  145. let old_node = node_state.get(old_id).unwrap();
  146. let parent_id = old_node.parent.expect("tried to insert before root");
  147. (&mut node_state).get(new_id).unwrap().parent = Some(parent_id);
  148. let parent = (&mut node_state).get(parent_id).unwrap();
  149. let index = parent
  150. .children
  151. .iter()
  152. .position(|child| *child == old_id)
  153. .unwrap();
  154. parent.children.insert(index + 1, new_id);
  155. let height = parent.height + 1;
  156. set_height(self, new_id, height);
  157. }
  158. }
  159. /// Sets the height of a node and updates the height of all its children
  160. fn set_height(tree: &mut TreeMutView<'_>, node: NodeId, height: u16) {
  161. let children = {
  162. let mut node_data_mut = &mut tree.1;
  163. let mut node = (&mut node_data_mut).get(node).unwrap();
  164. node.height = height;
  165. node.children.clone()
  166. };
  167. for child in children {
  168. set_height(tree, child, height + 1);
  169. }
  170. }
  171. impl<'a> TreeRef for TreeMutView<'a> {
  172. fn parent_id(&self, id: NodeId) -> Option<NodeId> {
  173. let node_data = &self.1;
  174. node_data.get(id).unwrap().parent
  175. }
  176. fn children_ids(&self, id: NodeId) -> Vec<NodeId> {
  177. let node_data = &self.1;
  178. node_data
  179. .get(id)
  180. .map(|node| node.children.clone())
  181. .unwrap_or_default()
  182. }
  183. fn height(&self, id: NodeId) -> Option<u16> {
  184. let node_data = &self.1;
  185. node_data.get(id).map(|node| node.height).ok()
  186. }
  187. fn contains(&self, id: NodeId) -> bool {
  188. self.1.get(id).is_ok()
  189. }
  190. }
  #[test]
  fn creation() {
      use shipyard::World;

      #[derive(Component)]
      struct Num(i32);

      let mut world = World::new();
      let root = world.add_entity(Num(1i32));
      let leaf = world.add_entity(Num(0i32));

      let mut tree = world.borrow::<TreeMutView>().unwrap();
      for id in [root, leaf] {
          tree.create_node(id);
      }
      tree.add_child(root, leaf);

      // Heights follow depth from the root.
      assert_eq!(tree.height(root), Some(0));
      assert_eq!(tree.height(leaf), Some(1));
      // Parent/child links are recorded on both sides.
      assert_eq!(tree.parent_id(root), None);
      assert_eq!(tree.parent_id(leaf), Some(root));
      assert_eq!(tree.children_ids(root), vec![leaf]);
  }
  #[test]
  fn insertion() {
      use shipyard::World;

      #[derive(Component)]
      struct Num(i32);

      let mut world = World::new();
      let root = world.add_entity(Num(0));
      let middle = world.add_entity(Num(2));
      let first = world.add_entity(Num(1));
      let last = world.add_entity(Num(3));

      let mut tree = world.borrow::<TreeMutView>().unwrap();
      for id in [root, middle, first, last] {
          tree.create_node(id);
      }
      tree.add_child(root, middle);
      tree.insert_before(middle, first);
      tree.insert_after(middle, last);

      assert_eq!(tree.height(root), Some(0));
      // Every inserted sibling sits one level below the root.
      for sibling in [first, middle, last] {
          assert_eq!(tree.height(sibling), Some(1));
          assert_eq!(tree.parent_id(sibling), Some(root));
      }
      assert_eq!(tree.children_ids(root), vec![first, middle, last]);
  }
  #[test]
  fn deletion() {
      use shipyard::World;

      #[derive(Component)]
      struct Num(i32);

      let mut world = World::new();
      let root = world.add_entity(Num(0));
      let middle = world.add_entity(Num(2));
      let first = world.add_entity(Num(1));
      let last = world.add_entity(Num(3));

      let mut tree = world.borrow::<TreeMutView>().unwrap();
      for id in [root, middle, first, last] {
          tree.create_node(id);
      }
      tree.add_child(root, middle);
      tree.insert_before(middle, first);
      tree.insert_after(middle, last);

      // Starting shape: root -> [first, middle, last].
      assert_eq!(tree.height(root), Some(0));
      for sibling in [first, middle, last] {
          assert_eq!(tree.height(sibling), Some(1));
          assert_eq!(tree.parent_id(sibling), Some(root));
      }
      assert_eq!(tree.children_ids(root), vec![first, middle, last]);

      // Removing the middle child leaves its siblings untouched.
      tree.remove(middle);
      assert_eq!(tree.height(root), Some(0));
      for sibling in [first, last] {
          assert_eq!(tree.height(sibling), Some(1));
          assert_eq!(tree.parent_id(sibling), Some(root));
      }
      assert_eq!(tree.children_ids(root), vec![first, last]);

      tree.remove(first);
      assert_eq!(tree.height(root), Some(0));
      assert_eq!(tree.height(last), Some(1));
      assert_eq!(tree.parent_id(last), Some(root));
      assert_eq!(tree.children_ids(root), vec![last]);

      // Removing the final child leaves the root childless.
      tree.remove(last);
      assert_eq!(tree.height(root), Some(0));
      assert!(tree.children_ids(root).is_empty());
  }