1
0

segment.rs 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. use quote::{format_ident, quote};
  2. use syn::{Ident, Type};
  3. use proc_macro2::{Span, TokenStream as TokenStream2};
  4. use crate::query::QuerySegment;
  5. #[derive(Debug, Clone)]
  6. pub enum RouteSegment {
  7. Static(String),
  8. Dynamic(Ident, Type),
  9. CatchAll(Ident, Type),
  10. }
  11. impl RouteSegment {
  12. pub fn name(&self) -> Option<Ident> {
  13. match self {
  14. Self::Static(_) => None,
  15. Self::Dynamic(ident, _) => Some(ident.clone()),
  16. Self::CatchAll(ident, _) => Some(ident.clone()),
  17. }
  18. }
  19. pub fn write_segment(&self) -> TokenStream2 {
  20. match self {
  21. Self::Static(segment) => quote! { write!(f, "/{}", #segment)?; },
  22. Self::Dynamic(ident, _) => quote! { write!(f, "/{}", #ident)?; },
  23. Self::CatchAll(ident, _) => quote! { #ident.display_route_segements(f)?; },
  24. }
  25. }
  26. pub fn error_name(&self, idx: usize) -> Ident {
  27. match self {
  28. Self::Static(_) => static_segment_idx(idx),
  29. Self::Dynamic(ident, _) => format_ident!("{}ParseError", ident),
  30. Self::CatchAll(ident, _) => format_ident!("{}ParseError", ident),
  31. }
  32. }
  33. pub fn missing_error_name(&self) -> Option<Ident> {
  34. match self {
  35. Self::Dynamic(ident, _) => Some(format_ident!("{}MissingError", ident)),
  36. _ => None,
  37. }
  38. }
  39. pub fn try_parse(
  40. &self,
  41. idx: usize,
  42. error_enum_name: &Ident,
  43. error_enum_varient: &Ident,
  44. inner_parse_enum: &Ident,
  45. parse_children: TokenStream2,
  46. ) -> TokenStream2 {
  47. let error_name = self.error_name(idx);
  48. match self {
  49. Self::Static(segment) => {
  50. quote! {
  51. {
  52. let mut segments = segments.clone();
  53. let parsed = if let Some(#segment) = segments.next() {
  54. Ok(())
  55. } else {
  56. Err(#error_enum_name::#error_enum_varient(#inner_parse_enum::#error_name))
  57. };
  58. match parsed {
  59. Ok(_) => {
  60. #parse_children
  61. }
  62. Err(err) => {
  63. errors.push(err);
  64. }
  65. }
  66. }
  67. }
  68. }
  69. Self::Dynamic(name, ty) => {
  70. let missing_error_name = self.missing_error_name().unwrap();
  71. quote! {
  72. {
  73. let mut segments = segments.clone();
  74. let parsed = if let Some(segment) = segments.next() {
  75. <#ty as dioxus_router::routable::FromRouteSegment>::from_route_segment(segment).map_err(|err| #error_enum_name::#error_enum_varient(#inner_parse_enum::#error_name(err)))
  76. } else {
  77. Err(#error_enum_name::#error_enum_varient(#inner_parse_enum::#missing_error_name))
  78. };
  79. match parsed {
  80. Ok(#name) => {
  81. #parse_children
  82. }
  83. Err(err) => {
  84. errors.push(err);
  85. }
  86. }
  87. }
  88. }
  89. }
  90. Self::CatchAll(name, ty) => {
  91. quote! {
  92. {
  93. let parsed = {
  94. let mut segments = segments.clone();
  95. let segments: Vec<_> = segments.collect();
  96. <#ty as dioxus_router::routable::FromRouteSegments>::from_route_segments(&segments).map_err(|err| #error_enum_name::#error_enum_varient(#inner_parse_enum::#error_name(err)))
  97. };
  98. match parsed {
  99. Ok(#name) => {
  100. #parse_children
  101. }
  102. Err(err) => {
  103. errors.push(err);
  104. }
  105. }
  106. }
  107. }
  108. }
  109. }
  110. }
  111. }
  112. pub fn static_segment_idx(idx: usize) -> Ident {
  113. format_ident!("StaticSegment{}ParseError", idx)
  114. }
  115. pub fn parse_route_segments<'a>(
  116. route_span: Span,
  117. mut fields: impl Iterator<Item = (&'a Ident, &'a Type)>,
  118. route: &str,
  119. ) -> syn::Result<(Vec<RouteSegment>, Option<QuerySegment>)> {
  120. let mut route_segments = Vec::new();
  121. let (route_string, query) = match route.rsplit_once('?') {
  122. Some((route, query)) => (route, Some(query)),
  123. None => (route, None),
  124. };
  125. let mut iterator = route_string.split('/');
  126. // skip the first empty segment
  127. let first = iterator.next();
  128. if first != Some("") {
  129. return Err(syn::Error::new(
  130. route_span,
  131. format!(
  132. "Routes should start with /. Error found in the route '{}'",
  133. route
  134. ),
  135. ));
  136. }
  137. while let Some(segment) = iterator.next() {
  138. if let Some(segment) = segment.strip_prefix(':') {
  139. let spread = segment.starts_with("...");
  140. let ident = if spread {
  141. segment[3..].to_string()
  142. } else {
  143. segment.to_string()
  144. };
  145. let field = fields.find(|(name, _)| **name == ident);
  146. let ty = if let Some(field) = field {
  147. field.1.clone()
  148. } else {
  149. return Err(syn::Error::new(
  150. route_span,
  151. format!("Could not find a field with the name '{}'", ident,),
  152. ));
  153. };
  154. if spread {
  155. route_segments.push(RouteSegment::CatchAll(
  156. Ident::new(&ident, Span::call_site()),
  157. ty,
  158. ));
  159. if iterator.next().is_some() {
  160. return Err(syn::Error::new(
  161. route_span,
  162. "Catch-all route segments must be the last segment in a route. The route segments after the catch-all segment will never be matched.",
  163. ));
  164. } else {
  165. break;
  166. }
  167. } else {
  168. route_segments.push(RouteSegment::Dynamic(
  169. Ident::new(&ident, Span::call_site()),
  170. ty,
  171. ));
  172. }
  173. } else {
  174. route_segments.push(RouteSegment::Static(segment.to_string()));
  175. }
  176. }
  177. // check if the route has a query string
  178. let parsed_query = match query {
  179. Some(query) => {
  180. if let Some(query) = query.strip_prefix(':') {
  181. let query_ident = Ident::new(query, Span::call_site());
  182. let field = fields.find(|(name, _)| *name == &query_ident);
  183. let ty = if let Some((_, ty)) = field {
  184. ty.clone()
  185. } else {
  186. return Err(syn::Error::new(
  187. route_span,
  188. format!("Could not find a field with the name '{}'", query_ident),
  189. ));
  190. };
  191. Some(QuerySegment {
  192. ident: query_ident,
  193. ty,
  194. })
  195. } else {
  196. None
  197. }
  198. }
  199. None => None,
  200. };
  201. Ok((route_segments, parsed_query))
  202. }
  203. pub(crate) fn create_error_type(error_name: Ident, segments: &[RouteSegment]) -> TokenStream2 {
  204. let mut error_variants = Vec::new();
  205. let mut display_match = Vec::new();
  206. for (i, segment) in segments.iter().enumerate() {
  207. let error_name = segment.error_name(i);
  208. match segment {
  209. RouteSegment::Static(index) => {
  210. error_variants.push(quote! { #error_name });
  211. display_match.push(quote! { Self::#error_name => write!(f, "Static segment '{}' did not match", #index)? });
  212. }
  213. RouteSegment::Dynamic(ident, ty) => {
  214. let missing_error = segment.missing_error_name().unwrap();
  215. error_variants.push(
  216. quote! { #error_name(<#ty as dioxus_router::routable::FromRouteSegment>::Err) },
  217. );
  218. display_match.push(quote! { Self::#error_name(err) => write!(f, "Dynamic segment '({}:{})' did not match: {}", stringify!(#ident), stringify!(#ty), err)? });
  219. error_variants.push(quote! { #missing_error });
  220. display_match.push(quote! { Self::#missing_error => write!(f, "Dynamic segment '({}:{})' was missing", stringify!(#ident), stringify!(#ty))? });
  221. }
  222. RouteSegment::CatchAll(ident, ty) => {
  223. error_variants.push(quote! { #error_name(<#ty as dioxus_router::routable::FromRouteSegments>::Err) });
  224. display_match.push(quote! { Self::#error_name(err) => write!(f, "Catch-all segment '({}:{})' did not match: {}", stringify!(#ident), stringify!(#ty), err)? });
  225. }
  226. }
  227. }
  228. quote! {
  229. #[allow(non_camel_case_types)]
  230. #[derive(Debug, PartialEq)]
  231. pub enum #error_name {
  232. ExtraSegments(String),
  233. #(#error_variants,)*
  234. }
  235. impl std::fmt::Display for #error_name {
  236. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  237. match self {
  238. Self::ExtraSegments(segments) => {
  239. write!(f, "Found additional trailing segments: {}", segments)?
  240. }
  241. #(#display_match,)*
  242. }
  243. Ok(())
  244. }
  245. }
  246. }
  247. }