lib.rs 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. extern crate proc_macro;
  2. use nest::{Layout, Nest};
  3. use proc_macro::TokenStream;
  4. use quote::{__private::Span, format_ident, quote, ToTokens};
  5. use route::Route;
  6. use syn::{parse_macro_input, Ident};
  7. use proc_macro2::TokenStream as TokenStream2;
  8. use crate::{nest::LayoutId, route_tree::RouteTree};
  9. mod nest;
  10. mod query;
  11. mod route;
  12. mod route_tree;
  13. mod segment;
  14. // #[proc_macro_derive(Routable, attributes(route, nest, end_nest))]
  15. #[proc_macro_attribute]
  16. pub fn routable(_: TokenStream, input: TokenStream) -> TokenStream {
  17. let routes_enum = parse_macro_input!(input as syn::ItemEnum);
  18. let route_enum = match RouteEnum::parse(routes_enum) {
  19. Ok(route_enum) => route_enum,
  20. Err(err) => return err.to_compile_error().into(),
  21. };
  22. let error_type = route_enum.error_type();
  23. let parse_impl = route_enum.parse_impl();
  24. let display_impl = route_enum.impl_display();
  25. let routable_impl = route_enum.routable_impl();
  26. quote! {
  27. #route_enum
  28. #error_type
  29. #parse_impl
  30. #display_impl
  31. #routable_impl
  32. }
  33. .into()
  34. }
  35. struct RouteEnum {
  36. vis: syn::Visibility,
  37. attrs: Vec<syn::Attribute>,
  38. name: Ident,
  39. routes: Vec<Route>,
  40. layouts: Vec<Layout>,
  41. }
  42. impl RouteEnum {
  43. fn parse(data: syn::ItemEnum) -> syn::Result<Self> {
  44. let name = &data.ident;
  45. enum NestRef {
  46. Static(String),
  47. Dynamic { id: LayoutId },
  48. }
  49. let mut routes = Vec::new();
  50. let mut layouts = Vec::new();
  51. let mut nest_stack = Vec::new();
  52. for variant in data.variants {
  53. // Apply the any nesting attributes in order
  54. for attr in &variant.attrs {
  55. if attr.path.is_ident("nest") {
  56. let nest: Nest = attr.parse_args()?;
  57. let nest_ref = match nest {
  58. Nest::Static(s) => NestRef::Static(s),
  59. Nest::Layout(mut l) => {
  60. // if there is a static nest before this, add it to the layout
  61. let mut static_prefix = nest_stack
  62. .iter()
  63. // walk backwards and take all static nests
  64. .rev()
  65. .map_while(|nest| match nest {
  66. NestRef::Static(s) => Some(s.clone()),
  67. NestRef::Dynamic { .. } => None,
  68. })
  69. .collect::<Vec<_>>();
  70. // reverse the static prefix so it is in the correct order
  71. static_prefix.reverse();
  72. if !static_prefix.is_empty() {
  73. l.add_static_prefix(&static_prefix.join("/"));
  74. }
  75. let id = layouts.len();
  76. layouts.push(l);
  77. NestRef::Dynamic { id: LayoutId(id) }
  78. }
  79. };
  80. nest_stack.push(nest_ref);
  81. } else if attr.path.is_ident("end_nest") {
  82. nest_stack.pop();
  83. }
  84. }
  85. let mut trailing_static_route = nest_stack
  86. .iter()
  87. .rev()
  88. .map_while(|nest| match nest {
  89. NestRef::Static(s) => Some(s.clone()),
  90. NestRef::Dynamic { .. } => None,
  91. })
  92. .collect::<Vec<_>>();
  93. trailing_static_route.reverse();
  94. let active_layouts = nest_stack
  95. .iter()
  96. .filter_map(|nest| match nest {
  97. NestRef::Static(_) => None,
  98. NestRef::Dynamic { id } => Some(*id),
  99. })
  100. .collect::<Vec<_>>();
  101. let route = Route::parse(trailing_static_route.join("/"), active_layouts, variant)?;
  102. routes.push(route);
  103. }
  104. let myself = Self {
  105. vis: data.vis,
  106. attrs: data.attrs,
  107. name: name.clone(),
  108. routes,
  109. layouts,
  110. };
  111. Ok(myself)
  112. }
  113. fn impl_display(&self) -> TokenStream2 {
  114. let mut display_match = Vec::new();
  115. for route in &self.routes {
  116. display_match.push(route.display_match(&self.layouts));
  117. }
  118. let name = &self.name;
  119. quote! {
  120. impl std::fmt::Display for #name {
  121. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  122. match self {
  123. #(#display_match)*
  124. }
  125. Ok(())
  126. }
  127. }
  128. }
  129. }
  130. fn parse_impl(&self) -> TokenStream2 {
  131. let tree = RouteTree::new(&self.routes, &self.layouts);
  132. let name = &self.name;
  133. let error_name = format_ident!("{}MatchError", self.name);
  134. let tokens = tree.roots.iter().map(|&id| {
  135. let route = tree.get(id).unwrap();
  136. route.to_tokens(&tree, self.name.clone(), error_name.clone(), &self.layouts)
  137. });
  138. quote! {
  139. impl<'a> TryFrom<&'a str> for #name {
  140. type Error = <Self as std::str::FromStr>::Err;
  141. fn try_from(s: &'a str) -> Result<Self, Self::Error> {
  142. s.parse()
  143. }
  144. }
  145. impl std::str::FromStr for #name {
  146. type Err = RouteParseError<#error_name>;
  147. fn from_str(s: &str) -> Result<Self, Self::Err> {
  148. let route = s.strip_prefix('/').unwrap_or(s);
  149. let (route, query) = route.split_once('?').unwrap_or((route, ""));
  150. let mut segments = route.split('/');
  151. let mut errors = Vec::new();
  152. #(#tokens)*
  153. Err(RouteParseError {
  154. attempted_routes: errors,
  155. })
  156. }
  157. }
  158. }
  159. }
  160. fn error_name(&self) -> Ident {
  161. Ident::new(&(self.name.to_string() + "MatchError"), Span::call_site())
  162. }
  163. fn error_type(&self) -> TokenStream2 {
  164. let match_error_name = self.error_name();
  165. let mut type_defs = Vec::new();
  166. let mut error_variants = Vec::new();
  167. let mut display_match = Vec::new();
  168. for route in &self.routes {
  169. let route_name = &route.route_name;
  170. let error_name = route.error_ident();
  171. let route_str = &route.route;
  172. error_variants.push(quote! { #route_name(#error_name) });
  173. display_match.push(quote! { Self::#route_name(err) => write!(f, "Route '{}' ('{}') did not match:\n{}", stringify!(#route_name), #route_str, err)? });
  174. type_defs.push(route.error_type());
  175. }
  176. for layout in &self.layouts {
  177. let layout_name = &layout.layout_name;
  178. let error_name = layout.error_ident();
  179. let route_str = &layout.route;
  180. error_variants.push(quote! { #layout_name(#error_name) });
  181. display_match.push(quote! { Self::#layout_name(err) => write!(f, "Layout '{}' ('{}') did not match:\n{}", stringify!(#layout_name), #route_str, err)? });
  182. type_defs.push(layout.error_type());
  183. }
  184. quote! {
  185. #(#type_defs)*
  186. #[allow(non_camel_case_types)]
  187. #[derive(Debug, PartialEq)]
  188. pub enum #match_error_name {
  189. #(#error_variants),*
  190. }
  191. impl std::fmt::Display for #match_error_name {
  192. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  193. match self {
  194. #(#display_match),*
  195. }
  196. Ok(())
  197. }
  198. }
  199. }
  200. }
  201. fn routable_impl(&self) -> TokenStream2 {
  202. let name = &self.name;
  203. let mut layers = Vec::new();
  204. loop {
  205. let index = layers.len();
  206. let mut routable_match = Vec::new();
  207. // Collect all routes that match the current layer
  208. for route in &self.routes {
  209. if let Some(matched) = route.routable_match(&self.layouts, index) {
  210. routable_match.push(matched);
  211. }
  212. }
  213. // All routes are exhausted
  214. if routable_match.is_empty() {
  215. break;
  216. }
  217. layers.push(quote! {
  218. #(#routable_match)*
  219. });
  220. }
  221. let index_iter = 0..layers.len();
  222. quote! {
  223. impl Routable for #name where Self: Clone {
  224. fn render<'a>(&self, cx: &'a ScopeState, level: usize) -> Element<'a> {
  225. let myself = self.clone();
  226. match level {
  227. #(
  228. #index_iter => {
  229. match myself {
  230. #layers
  231. _ => panic!("Route::render called with invalid level {}", level),
  232. }
  233. },
  234. )*
  235. _ => panic!("Route::render called with invalid level {}", level),
  236. }
  237. }
  238. }
  239. }
  240. }
  241. }
  242. impl ToTokens for RouteEnum {
  243. fn to_tokens(&self, tokens: &mut quote::__private::TokenStream) {
  244. let routes = &self.routes;
  245. let vis = &self.vis;
  246. let name = &self.name;
  247. let attrs = &self.attrs;
  248. let variants = routes.iter().map(|r| r.variant(&self.layouts));
  249. tokens.extend(quote!(
  250. #(#attrs)*
  251. #vis enum #name {
  252. #(#variants),*
  253. }
  254. #[path = "pages"]
  255. mod pages {
  256. #(#routes)*
  257. }
  258. pub use pages::*;
  259. ));
  260. }
  261. }