lib.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. #![doc = include_str!("../README.md")]
  2. #![doc(html_logo_url = "https://avatars.githubusercontent.com/u/79236386")]
  3. #![doc(html_favicon_url = "https://avatars.githubusercontent.com/u/79236386")]
  4. extern crate proc_macro;
  5. use std::collections::HashSet;
  6. use proc_macro::TokenStream;
  7. use quote::{format_ident, quote};
  8. use syn::{parse_macro_input, ItemImpl, Type, TypePath, TypeTuple};
  9. /// A helper attribute for deriving `State` for a struct.
  10. #[proc_macro_attribute]
  11. pub fn partial_derive_state(_: TokenStream, input: TokenStream) -> TokenStream {
  12. let impl_block: syn::ItemImpl = parse_macro_input!(input as syn::ItemImpl);
  13. let has_create_fn = impl_block
  14. .items
  15. .iter()
  16. .any(|item| matches!(item, syn::ImplItem::Fn(method) if method.sig.ident == "create"));
  17. let parent_dependencies = impl_block
  18. .items
  19. .iter()
  20. .find_map(|item| {
  21. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  22. (ident == "ParentDependencies").then_some(ty)
  23. } else {
  24. None
  25. }
  26. })
  27. .expect("ParentDependencies must be defined");
  28. let child_dependencies = impl_block
  29. .items
  30. .iter()
  31. .find_map(|item| {
  32. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  33. (ident == "ChildDependencies").then_some(ty)
  34. } else {
  35. None
  36. }
  37. })
  38. .expect("ChildDependencies must be defined");
  39. let node_dependencies = impl_block
  40. .items
  41. .iter()
  42. .find_map(|item| {
  43. if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
  44. (ident == "NodeDependencies").then_some(ty)
  45. } else {
  46. None
  47. }
  48. })
  49. .expect("NodeDependencies must be defined");
  50. let this_type = &impl_block.self_ty;
  51. let this_type = extract_type_path(this_type)
  52. .unwrap_or_else(|| panic!("Self must be a type path, found {}", quote!(#this_type)));
  53. let mut combined_dependencies = HashSet::new();
  54. let self_path: TypePath = syn::parse_quote!(Self);
  55. let parent_dependencies = match extract_tuple(parent_dependencies) {
  56. Some(tuple) => {
  57. let mut parent_dependencies = Vec::new();
  58. for type_ in &tuple.elems {
  59. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  60. panic!(
  61. "ParentDependencies must be a tuple of type paths, found {}",
  62. quote!(#type_)
  63. )
  64. });
  65. if type_ == self_path {
  66. type_ = this_type.clone();
  67. }
  68. combined_dependencies.insert(type_.clone());
  69. parent_dependencies.push(type_);
  70. }
  71. parent_dependencies
  72. }
  73. _ => panic!(
  74. "ParentDependencies must be a tuple, found {}",
  75. quote!(#parent_dependencies)
  76. ),
  77. };
  78. let child_dependencies = match extract_tuple(child_dependencies) {
  79. Some(tuple) => {
  80. let mut child_dependencies = Vec::new();
  81. for type_ in &tuple.elems {
  82. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  83. panic!(
  84. "ChildDependencies must be a tuple of type paths, found {}",
  85. quote!(#type_)
  86. )
  87. });
  88. if type_ == self_path {
  89. type_ = this_type.clone();
  90. }
  91. combined_dependencies.insert(type_.clone());
  92. child_dependencies.push(type_);
  93. }
  94. child_dependencies
  95. }
  96. _ => panic!(
  97. "ChildDependencies must be a tuple, found {}",
  98. quote!(#child_dependencies)
  99. ),
  100. };
  101. let node_dependencies = match extract_tuple(node_dependencies) {
  102. Some(tuple) => {
  103. let mut node_dependencies = Vec::new();
  104. for type_ in &tuple.elems {
  105. let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
  106. panic!(
  107. "NodeDependencies must be a tuple of type paths, found {}",
  108. quote!(#type_)
  109. )
  110. });
  111. if type_ == self_path {
  112. type_ = this_type.clone();
  113. }
  114. combined_dependencies.insert(type_.clone());
  115. node_dependencies.push(type_);
  116. }
  117. node_dependencies
  118. }
  119. _ => panic!(
  120. "NodeDependencies must be a tuple, found {}",
  121. quote!(#node_dependencies)
  122. ),
  123. };
  124. combined_dependencies.insert(this_type.clone());
  125. let combined_dependencies: Vec<_> = combined_dependencies.into_iter().collect();
  126. let parent_dependancies_idxes: Vec<_> = parent_dependencies
  127. .iter()
  128. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  129. .collect();
  130. let child_dependencies_idxes: Vec<_> = child_dependencies
  131. .iter()
  132. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  133. .collect();
  134. let node_dependencies_idxes: Vec<_> = node_dependencies
  135. .iter()
  136. .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
  137. .collect();
  138. let this_type_idx = combined_dependencies
  139. .iter()
  140. .enumerate()
  141. .find_map(|(i, ident)| (this_type == *ident).then_some(i))
  142. .unwrap();
  143. let this_view = format_ident!("__data{}", this_type_idx);
  144. let combined_dependencies_quote = combined_dependencies.iter().map(|ident| {
  145. if ident == &this_type {
  146. quote! {shipyard::ViewMut<#ident>}
  147. } else {
  148. quote! {shipyard::View<#ident>}
  149. }
  150. });
  151. let combined_dependencies_quote = quote!((#(#combined_dependencies_quote,)*));
  152. let ItemImpl {
  153. attrs,
  154. defaultness,
  155. unsafety,
  156. impl_token,
  157. generics,
  158. trait_,
  159. self_ty,
  160. items,
  161. ..
  162. } = impl_block;
  163. let for_ = trait_.as_ref().map(|t| t.2);
  164. let trait_ = trait_.map(|t| t.1);
  165. let split_views: Vec<_> = (0..combined_dependencies.len())
  166. .map(|i| {
  167. let ident = format_ident!("__data{}", i);
  168. if i == this_type_idx {
  169. quote! {mut #ident}
  170. } else {
  171. quote! {#ident}
  172. }
  173. })
  174. .collect();
  175. let node_view = node_dependencies_idxes
  176. .iter()
  177. .map(|i| format_ident!("__data{}", i))
  178. .collect::<Vec<_>>();
  179. let get_node_view = {
  180. if node_dependencies.is_empty() {
  181. quote! {
  182. let raw_node = ();
  183. }
  184. } else {
  185. let temps = (0..node_dependencies.len())
  186. .map(|i| format_ident!("__temp{}", i))
  187. .collect::<Vec<_>>();
  188. quote! {
  189. let raw_node: (#(*const #node_dependencies,)*) = {
  190. let (#(#temps,)*) = (#(&#node_view,)*).get(id).unwrap_or_else(|err| panic!("Failed to get node view {:?}", err));
  191. (#(#temps as *const _,)*)
  192. };
  193. }
  194. }
  195. };
  196. let deref_node_view = {
  197. if node_dependencies.is_empty() {
  198. quote! {
  199. let node = raw_node;
  200. }
  201. } else {
  202. let indexes = (0..node_dependencies.len()).map(syn::Index::from);
  203. quote! {
  204. let node = unsafe { (#(dioxus_native_core::prelude::DependancyView::new(&*raw_node.#indexes),)*) };
  205. }
  206. }
  207. };
  208. let parent_view = parent_dependancies_idxes
  209. .iter()
  210. .map(|i| format_ident!("__data{}", i))
  211. .collect::<Vec<_>>();
  212. let get_parent_view = {
  213. if parent_dependencies.is_empty() {
  214. quote! {
  215. let raw_parent = tree.parent_id_advanced(id, Self::TRAVERSE_SHADOW_DOM).map(|_| ());
  216. }
  217. } else {
  218. let temps = (0..parent_dependencies.len())
  219. .map(|i| format_ident!("__temp{}", i))
  220. .collect::<Vec<_>>();
  221. quote! {
  222. let raw_parent = tree.parent_id_advanced(id, Self::TRAVERSE_SHADOW_DOM).and_then(|parent_id| {
  223. let raw_parent: Option<(#(*const #parent_dependencies,)*)> = (#(&#parent_view,)*).get(parent_id).ok().map(|c| {
  224. let (#(#temps,)*) = c;
  225. (#(#temps as *const _,)*)
  226. });
  227. raw_parent
  228. });
  229. }
  230. }
  231. };
  232. let deref_parent_view = {
  233. if parent_dependencies.is_empty() {
  234. quote! {
  235. let parent = raw_parent;
  236. }
  237. } else {
  238. let indexes = (0..parent_dependencies.len()).map(syn::Index::from);
  239. quote! {
  240. let parent = unsafe { raw_parent.map(|raw_parent| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_parent.#indexes),)*)) };
  241. }
  242. }
  243. };
  244. let child_view = child_dependencies_idxes
  245. .iter()
  246. .map(|i| format_ident!("__data{}", i))
  247. .collect::<Vec<_>>();
  248. let get_child_view = {
  249. if child_dependencies.is_empty() {
  250. quote! {
  251. let raw_children: Vec<_> = tree.children_ids_advanced(id, Self::TRAVERSE_SHADOW_DOM).into_iter().map(|_| ()).collect();
  252. }
  253. } else {
  254. let temps = (0..child_dependencies.len())
  255. .map(|i| format_ident!("__temp{}", i))
  256. .collect::<Vec<_>>();
  257. quote! {
  258. let raw_children: Vec<_> = tree.children_ids_advanced(id, Self::TRAVERSE_SHADOW_DOM).into_iter().filter_map(|id| {
  259. let raw_children: Option<(#(*const #child_dependencies,)*)> = (#(&#child_view,)*).get(id).ok().map(|c| {
  260. let (#(#temps,)*) = c;
  261. (#(#temps as *const _,)*)
  262. });
  263. raw_children
  264. }).collect();
  265. }
  266. }
  267. };
  268. let deref_child_view = {
  269. if child_dependencies.is_empty() {
  270. quote! {
  271. let children = raw_children;
  272. }
  273. } else {
  274. let indexes = (0..child_dependencies.len()).map(syn::Index::from);
  275. quote! {
  276. let children = unsafe { raw_children.iter().map(|raw_children| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_children.#indexes),)*)).collect::<Vec<_>>() };
  277. }
  278. }
  279. };
  280. let trait_generics = trait_
  281. .as_ref()
  282. .unwrap()
  283. .segments
  284. .last()
  285. .unwrap()
  286. .arguments
  287. .clone();
  288. // if a create function is defined, we don't generate one
  289. // otherwise we generate a default one that uses the update function and the default constructor
  290. let create_fn = (!has_create_fn).then(|| {
  291. quote! {
  292. fn create<'a>(
  293. node_view: dioxus_native_core::prelude::NodeView # trait_generics,
  294. node: <Self::NodeDependencies as Dependancy>::ElementBorrowed<'a>,
  295. parent: Option<<Self::ParentDependencies as Dependancy>::ElementBorrowed<'a>>,
  296. children: Vec<<Self::ChildDependencies as Dependancy>::ElementBorrowed<'a>>,
  297. context: &dioxus_native_core::prelude::SendAnyMap,
  298. ) -> Self {
  299. let mut myself = Self::default();
  300. myself.update(node_view, node, parent, children, context);
  301. myself
  302. }
  303. }
  304. });
  305. quote!(
  306. #(#attrs)*
  307. #defaultness #unsafety #impl_token #generics #trait_ #for_ #self_ty {
  308. #create_fn
  309. #(#items)*
  310. fn workload_system(type_id: std::any::TypeId, dependants: std::sync::Arc<dioxus_native_core::prelude::Dependants>, pass_direction: dioxus_native_core::prelude::PassDirection) -> dioxus_native_core::exports::shipyard::WorkloadSystem {
  311. use dioxus_native_core::exports::shipyard::{IntoWorkloadSystem, Get, AddComponent};
  312. use dioxus_native_core::tree::TreeRef;
  313. use dioxus_native_core::prelude::{NodeType, NodeView};
  314. let node_mask = Self::NODE_MASK.build();
  315. (move |data: #combined_dependencies_quote, run_view: dioxus_native_core::prelude::RunPassView #trait_generics| {
  316. let (#(#split_views,)*) = data;
  317. let tree = run_view.tree.clone();
  318. let node_types = run_view.node_type.clone();
  319. dioxus_native_core::prelude::run_pass(type_id, dependants.clone(), pass_direction, run_view, |id, context| {
  320. let node_data: &NodeType<_> = node_types.get(id).unwrap_or_else(|err| panic!("Failed to get node type {:?}", err));
  321. // get all of the states from the tree view
  322. // Safety: No node has itself as a parent or child.
  323. let raw_myself: Option<*mut Self> = (&mut #this_view).get(id).ok().map(|c| c as *mut _);
  324. #get_node_view
  325. #get_parent_view
  326. #get_child_view
  327. let myself: Option<&mut Self> = unsafe { raw_myself.map(|val| &mut *val) };
  328. #deref_node_view
  329. #deref_parent_view
  330. #deref_child_view
  331. let view = NodeView::new(id, node_data, &node_mask);
  332. if let Some(myself) = myself {
  333. myself
  334. .update(view, node, parent, children, context)
  335. }
  336. else {
  337. (&mut #this_view).add_component_unchecked(
  338. id,
  339. Self::create(view, node, parent, children, context));
  340. true
  341. }
  342. })
  343. }).into_workload_system().unwrap()
  344. }
  345. }
  346. )
  347. .into()
  348. }
  349. fn extract_tuple(ty: &Type) -> Option<TypeTuple> {
  350. match ty {
  351. Type::Tuple(tuple) => Some(tuple.clone()),
  352. Type::Group(group) => extract_tuple(&group.elem),
  353. _ => None,
  354. }
  355. }
  356. fn extract_type_path(ty: &Type) -> Option<TypePath> {
  357. match ty {
  358. Type::Path(path) => Some(path.clone()),
  359. Type::Group(group) => extract_type_path(&group.elem),
  360. _ => None,
  361. }
  362. }