lib.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. extern crate proc_macro;
  2. use std::collections::HashSet;
  3. use proc_macro::TokenStream;
  4. use quote::{format_ident, quote};
  5. use syn::{parse_macro_input, ItemImpl, Type, TypePath, TypeTuple};
  6. /// A helper attribute for deriving `State` for a struct.
  7. #[proc_macro_attribute]
  8. pub fn partial_derive_state(_: TokenStream, input: TokenStream) -> TokenStream {
  9. let impl_block: syn::ItemImpl = parse_macro_input!(input as syn::ItemImpl);
  10. let has_create_fn = impl_block
  11. .items
  12. .iter()
  13. .any(|item| matches!(item, syn::ImplItem::Method(method) if method.sig.ident == "create"));
  14. let parent_dependencies = impl_block
  15. .items
  16. .iter()
  17. .find_map(|item| {
  18. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  19. (ident == "ParentDependencies").then_some(ty)
  20. } else {
  21. None
  22. }
  23. })
  24. .expect("ParentDependencies must be defined");
  25. let child_dependencies = impl_block
  26. .items
  27. .iter()
  28. .find_map(|item| {
  29. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  30. (ident == "ChildDependencies").then_some(ty)
  31. } else {
  32. None
  33. }
  34. })
  35. .expect("ChildDependencies must be defined");
  36. let node_dependencies = impl_block
  37. .items
  38. .iter()
  39. .find_map(|item| {
  40. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  41. (ident == "NodeDependencies").then_some(ty)
  42. } else {
  43. None
  44. }
  45. })
  46. .expect("NodeDependencies must be defined");
  47. let this_type = &impl_block.self_ty;
  48. let this_type = extract_type_path(this_type)
  49. .unwrap_or_else(|| panic!("Self must be a type path, found {}", quote!(#this_type)));
  50. let mut combined_dependencies = HashSet::new();
  51. let self_path: TypePath = syn::parse_quote!(Self);
  52. let parent_dependencies = match extract_tuple(parent_dependencies) {
  53. Some(tuple) => {
  54. let mut parent_dependencies = Vec::new();
  55. for type_ in &tuple.elems {
  56. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  57. panic!(
  58. "ParentDependencies must be a tuple of type paths, found {}",
  59. quote!(#type_)
  60. )
  61. });
  62. if type_ == self_path {
  63. type_ = this_type.clone();
  64. }
  65. combined_dependencies.insert(type_.clone());
  66. parent_dependencies.push(type_);
  67. }
  68. parent_dependencies
  69. }
  70. _ => panic!(
  71. "ParentDependencies must be a tuple, found {}",
  72. quote!(#parent_dependencies)
  73. ),
  74. };
  75. let child_dependencies = match extract_tuple(child_dependencies) {
  76. Some(tuple) => {
  77. let mut child_dependencies = Vec::new();
  78. for type_ in &tuple.elems {
  79. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  80. panic!(
  81. "ChildDependencies must be a tuple of type paths, found {}",
  82. quote!(#type_)
  83. )
  84. });
  85. if type_ == self_path {
  86. type_ = this_type.clone();
  87. }
  88. combined_dependencies.insert(type_.clone());
  89. child_dependencies.push(type_);
  90. }
  91. child_dependencies
  92. }
  93. _ => panic!(
  94. "ChildDependencies must be a tuple, found {}",
  95. quote!(#child_dependencies)
  96. ),
  97. };
  98. let node_dependencies = match extract_tuple(node_dependencies) {
  99. Some(tuple) => {
  100. let mut node_dependencies = Vec::new();
  101. for type_ in &tuple.elems {
  102. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  103. panic!(
  104. "NodeDependencies must be a tuple of type paths, found {}",
  105. quote!(#type_)
  106. )
  107. });
  108. if type_ == self_path {
  109. type_ = this_type.clone();
  110. }
  111. combined_dependencies.insert(type_.clone());
  112. node_dependencies.push(type_);
  113. }
  114. node_dependencies
  115. }
  116. _ => panic!(
  117. "NodeDependencies must be a tuple, found {}",
  118. quote!(#node_dependencies)
  119. ),
  120. };
  121. combined_dependencies.insert(this_type.clone());
  122. let combined_dependencies: Vec<_> = combined_dependencies.into_iter().collect();
  123. let parent_dependancies_idxes: Vec<_> = parent_dependencies
  124. .iter()
  125. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  126. .collect();
  127. let child_dependencies_idxes: Vec<_> = child_dependencies
  128. .iter()
  129. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  130. .collect();
  131. let node_dependencies_idxes: Vec<_> = node_dependencies
  132. .iter()
  133. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  134. .collect();
  135. let this_type_idx = combined_dependencies
  136. .iter()
  137. .enumerate()
  138. .find_map(|(i, ident)| (this_type == *ident).then_some(i))
  139. .unwrap();
  140. let this_view = format_ident!("__data{}", this_type_idx);
  141. let combined_dependencies_quote = combined_dependencies.iter().map(|ident| {
  142. if ident == &this_type {
  143. quote! {shipyard::ViewMut<#ident>}
  144. } else {
  145. quote! {shipyard::View<#ident>}
  146. }
  147. });
  148. let combined_dependencies_quote = quote!((#(#combined_dependencies_quote,)*));
  149. let ItemImpl {
  150. attrs,
  151. defaultness,
  152. unsafety,
  153. impl_token,
  154. generics,
  155. trait_,
  156. self_ty,
  157. items,
  158. ..
  159. } = impl_block;
  160. let for_ = trait_.as_ref().map(|t| t.2);
  161. let trait_ = trait_.map(|t| t.1);
  162. let split_views: Vec<_> = (0..combined_dependencies.len())
  163. .map(|i| {
  164. let ident = format_ident!("__data{}", i);
  165. if i == this_type_idx {
  166. quote! {mut #ident}
  167. } else {
  168. quote! {#ident}
  169. }
  170. })
  171. .collect();
  172. let node_view = node_dependencies_idxes
  173. .iter()
  174. .map(|i| format_ident!("__data{}", i))
  175. .collect::<Vec<_>>();
  176. let get_node_view = {
  177. if node_dependencies.is_empty() {
  178. quote! {
  179. let raw_node = ();
  180. }
  181. } else {
  182. let temps = (0..node_dependencies.len())
  183. .map(|i| format_ident!("__temp{}", i))
  184. .collect::<Vec<_>>();
  185. quote! {
  186. let raw_node: (#(*const #node_dependencies,)*) = {
  187. let (#(#temps,)*) = (#(&#node_view,)*).get(id).unwrap_or_else(|err| panic!("Failed to get node view {:?}", err));
  188. (#(#temps as *const _,)*)
  189. };
  190. }
  191. }
  192. };
  193. let deref_node_view = {
  194. if node_dependencies.is_empty() {
  195. quote! {
  196. let node = raw_node;
  197. }
  198. } else {
  199. let indexes = (0..node_dependencies.len()).map(syn::Index::from);
  200. quote! {
  201. let node = unsafe { (#(dioxus_native_core::prelude::DependancyView::new(&*raw_node.#indexes),)*) };
  202. }
  203. }
  204. };
  205. let parent_view = parent_dependancies_idxes
  206. .iter()
  207. .map(|i| format_ident!("__data{}", i))
  208. .collect::<Vec<_>>();
  209. let get_parent_view = {
  210. if parent_dependencies.is_empty() {
  211. quote! {
  212. let raw_parent = tree.parent_id(id).map(|_| ());
  213. }
  214. } else {
  215. let temps = (0..parent_dependencies.len())
  216. .map(|i| format_ident!("__temp{}", i))
  217. .collect::<Vec<_>>();
  218. quote! {
  219. let raw_parent = tree.parent_id(id).and_then(|parent_id| {
  220. let raw_parent: Option<(#(*const #parent_dependencies,)*)> = (#(&#parent_view,)*).get(parent_id).ok().map(|c| {
  221. let (#(#temps,)*) = c;
  222. (#(#temps as *const _,)*)
  223. });
  224. raw_parent
  225. });
  226. }
  227. }
  228. };
  229. let deref_parent_view = {
  230. if parent_dependencies.is_empty() {
  231. quote! {
  232. let parent = raw_parent;
  233. }
  234. } else {
  235. let indexes = (0..parent_dependencies.len()).map(syn::Index::from);
  236. quote! {
  237. let parent = unsafe { raw_parent.map(|raw_parent| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_parent.#indexes),)*)) };
  238. }
  239. }
  240. };
  241. let child_view = child_dependencies_idxes
  242. .iter()
  243. .map(|i| format_ident!("__data{}", i))
  244. .collect::<Vec<_>>();
  245. let get_child_view = {
  246. if child_dependencies.is_empty() {
  247. quote! {
  248. let raw_children: Vec<_> = tree.children_ids(id).into_iter().map(|_| ()).collect();
  249. }
  250. } else {
  251. let temps = (0..child_dependencies.len())
  252. .map(|i| format_ident!("__temp{}", i))
  253. .collect::<Vec<_>>();
  254. quote! {
  255. let raw_children: Vec<_> = tree.children_ids(id).into_iter().filter_map(|id| {
  256. let raw_children: Option<(#(*const #child_dependencies,)*)> = (#(&#child_view,)*).get(id).ok().map(|c| {
  257. let (#(#temps,)*) = c;
  258. (#(#temps as *const _,)*)
  259. });
  260. raw_children
  261. }).collect();
  262. }
  263. }
  264. };
  265. let deref_child_view = {
  266. if child_dependencies.is_empty() {
  267. quote! {
  268. let children = raw_children;
  269. }
  270. } else {
  271. let indexes = (0..child_dependencies.len()).map(syn::Index::from);
  272. quote! {
  273. let children = unsafe { raw_children.iter().map(|raw_children| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_children.#indexes),)*)).collect::<Vec<_>>() };
  274. }
  275. }
  276. };
  277. let trait_generics = trait_
  278. .as_ref()
  279. .unwrap()
  280. .segments
  281. .last()
  282. .unwrap()
  283. .arguments
  284. .clone();
  285. // if a create function is defined, we don't generate one
  286. // otherwise we generate a default one that uses the update function and the default constructor
  287. let create_fn = (!has_create_fn).then(|| {
  288. quote! {
  289. fn create<'a>(
  290. node_view: dioxus_native_core::prelude::NodeView # trait_generics,
  291. node: <Self::NodeDependencies as Dependancy>::ElementBorrowed<'a>,
  292. parent: Option<<Self::ParentDependencies as Dependancy>::ElementBorrowed<'a>>,
  293. children: Vec<<Self::ChildDependencies as Dependancy>::ElementBorrowed<'a>>,
  294. context: &dioxus_native_core::prelude::SendAnyMap,
  295. ) -> Self {
  296. let mut myself = Self::default();
  297. myself.update(node_view, node, parent, children, context);
  298. myself
  299. }
  300. }
  301. });
  302. quote!(
  303. #(#attrs)*
  304. #defaultness #unsafety #impl_token #generics #trait_ #for_ #self_ty {
  305. #create_fn
  306. #(#items)*
  307. fn workload_system(type_id: std::any::TypeId, dependants: std::sync::Arc<dioxus_native_core::prelude::Dependants>, pass_direction: dioxus_native_core::prelude::PassDirection) -> dioxus_native_core::exports::shipyard::WorkloadSystem {
  308. use dioxus_native_core::exports::shipyard::{IntoWorkloadSystem, Get, AddComponent};
  309. use dioxus_native_core::tree::TreeRef;
  310. use dioxus_native_core::prelude::{NodeType, NodeView};
  311. let node_mask = Self::NODE_MASK.build();
  312. (move |data: #combined_dependencies_quote, run_view: dioxus_native_core::prelude::RunPassView #trait_generics| {
  313. let (#(#split_views,)*) = data;
  314. let tree = run_view.tree.clone();
  315. let node_types = run_view.node_type.clone();
  316. dioxus_native_core::prelude::run_pass(type_id, dependants.clone(), pass_direction, run_view, |id, context| {
  317. let node_data: &NodeType<_> = node_types.get(id).unwrap_or_else(|err| panic!("Failed to get node type {:?}", err));
  318. // get all of the states from the tree view
  319. // Safety: No node has itself as a parent or child.
  320. let raw_myself: Option<*mut Self> = (&mut #this_view).get(id).ok().map(|c| c as *mut _);
  321. #get_node_view
  322. #get_parent_view
  323. #get_child_view
  324. let myself: Option<&mut Self> = unsafe { raw_myself.map(|val| &mut *val) };
  325. #deref_node_view
  326. #deref_parent_view
  327. #deref_child_view
  328. let view = NodeView::new(id, node_data, &node_mask);
  329. if let Some(myself) = myself {
  330. myself
  331. .update(view, node, parent, children, context)
  332. }
  333. else {
  334. (&mut #this_view).add_component_unchecked(
  335. id,
  336. Self::create(view, node, parent, children, context));
  337. true
  338. }
  339. })
  340. }).into_workload_system().unwrap()
  341. }
  342. }
  343. )
  344. .into()
  345. }
  346. fn extract_tuple(ty: &Type) -> Option<TypeTuple> {
  347. match ty {
  348. Type::Tuple(tuple) => Some(tuple.clone()),
  349. Type::Group(group) => extract_tuple(&group.elem),
  350. _ => None,
  351. }
  352. }
  353. fn extract_type_path(ty: &Type) -> Option<TypePath> {
  354. match ty {
  355. Type::Path(path) => Some(path.clone()),
  356. Type::Group(group) => extract_type_path(&group.elem),
  357. _ => None,
  358. }
  359. }