1
0

segment.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. use quote::{format_ident, quote};
  2. use syn::{Ident, Type};
  3. use proc_macro2::{Span, TokenStream as TokenStream2};
  4. use crate::query::{FullQuerySegment, QueryArgument, QuerySegment};
  5. #[derive(Debug, Clone)]
  6. pub enum RouteSegment {
  7. Static(String),
  8. Dynamic(Ident, Type),
  9. CatchAll(Ident, Type),
  10. }
  11. impl RouteSegment {
  12. pub fn name(&self) -> Option<Ident> {
  13. match self {
  14. Self::Static(_) => None,
  15. Self::Dynamic(ident, _) => Some(ident.clone()),
  16. Self::CatchAll(ident, _) => Some(ident.clone()),
  17. }
  18. }
  19. pub fn write_segment(&self) -> TokenStream2 {
  20. match self {
  21. Self::Static(segment) => quote! { write!(f, "/{}", #segment)?; },
  22. Self::Dynamic(ident, _) => quote! { write!(f, "/{}", #ident)?; },
  23. Self::CatchAll(ident, _) => quote! { #ident.display_route_segments(f)?; },
  24. }
  25. }
  26. pub fn error_name(&self, idx: usize) -> Ident {
  27. match self {
  28. Self::Static(_) => static_segment_idx(idx),
  29. Self::Dynamic(ident, _) => format_ident!("{}ParseError", ident),
  30. Self::CatchAll(ident, _) => format_ident!("{}ParseError", ident),
  31. }
  32. }
  33. pub fn missing_error_name(&self) -> Option<Ident> {
  34. match self {
  35. Self::Dynamic(ident, _) => Some(format_ident!("{}MissingError", ident)),
  36. _ => None,
  37. }
  38. }
  39. pub fn try_parse(
  40. &self,
  41. idx: usize,
  42. error_enum_name: &Ident,
  43. error_enum_variant: &Ident,
  44. inner_parse_enum: &Ident,
  45. parse_children: TokenStream2,
  46. ) -> TokenStream2 {
  47. let error_name = self.error_name(idx);
  48. match self {
  49. Self::Static(segment) => {
  50. quote! {
  51. {
  52. let mut segments = segments.clone();
  53. let segment = segments.next();
  54. let segment = segment.as_deref();
  55. let parsed = if let Some(#segment) = segment {
  56. Ok(())
  57. } else {
  58. Err(#error_enum_name::#error_enum_variant(#inner_parse_enum::#error_name(segment.map(|s|s.to_string()).unwrap_or_default())))
  59. };
  60. match parsed {
  61. Ok(_) => {
  62. #parse_children
  63. }
  64. Err(err) => {
  65. errors.push(err);
  66. }
  67. }
  68. }
  69. }
  70. }
  71. Self::Dynamic(name, ty) => {
  72. let missing_error_name = self.missing_error_name().unwrap();
  73. quote! {
  74. {
  75. let mut segments = segments.clone();
  76. let segment = segments.next();
  77. let parsed = if let Some(segment) = segment.as_deref() {
  78. <#ty as dioxus_router::routable::FromRouteSegment>::from_route_segment(segment).map_err(|err| #error_enum_name::#error_enum_variant(#inner_parse_enum::#error_name(err)))
  79. } else {
  80. Err(#error_enum_name::#error_enum_variant(#inner_parse_enum::#missing_error_name))
  81. };
  82. match parsed {
  83. Ok(#name) => {
  84. #parse_children
  85. }
  86. Err(err) => {
  87. errors.push(err);
  88. }
  89. }
  90. }
  91. }
  92. }
  93. Self::CatchAll(name, ty) => {
  94. quote! {
  95. {
  96. let parsed = {
  97. let remaining_segments: Vec<_> = segments.collect();
  98. let mut new_segments: Vec<&str> = Vec::new();
  99. for segment in &remaining_segments {
  100. new_segments.push(&*segment);
  101. }
  102. <#ty as dioxus_router::routable::FromRouteSegments>::from_route_segments(&new_segments).map_err(|err| #error_enum_name::#error_enum_variant(#inner_parse_enum::#error_name(err)))
  103. };
  104. match parsed {
  105. Ok(#name) => {
  106. #parse_children
  107. }
  108. Err(err) => {
  109. errors.push(err);
  110. }
  111. }
  112. }
  113. }
  114. }
  115. }
  116. }
  117. }
  118. pub fn static_segment_idx(idx: usize) -> Ident {
  119. format_ident!("StaticSegment{}ParseError", idx)
  120. }
  121. pub fn parse_route_segments<'a>(
  122. route_span: Span,
  123. fields: impl Iterator<Item = (&'a Ident, &'a Type)> + Clone,
  124. route: &str,
  125. ) -> syn::Result<(Vec<RouteSegment>, Option<QuerySegment>)> {
  126. let mut route_segments = Vec::new();
  127. let (route_string, query) = match route.rsplit_once('?') {
  128. Some((route, query)) => (route, Some(query)),
  129. None => (route, None),
  130. };
  131. let mut iterator = route_string.split('/');
  132. // skip the first empty segment
  133. let first = iterator.next();
  134. if first != Some("") {
  135. return Err(syn::Error::new(
  136. route_span,
  137. format!(
  138. "Routes should start with /. Error found in the route '{}'",
  139. route
  140. ),
  141. ));
  142. }
  143. while let Some(segment) = iterator.next() {
  144. if let Some(segment) = segment.strip_prefix(':') {
  145. let spread = segment.starts_with("..");
  146. let ident = if spread {
  147. segment[2..].to_string()
  148. } else {
  149. segment.to_string()
  150. };
  151. let field = fields.clone().find(|(name, _)| **name == ident);
  152. let ty = if let Some(field) = field {
  153. field.1.clone()
  154. } else {
  155. return Err(syn::Error::new(
  156. route_span,
  157. format!("Could not find a field with the name '{}'", ident,),
  158. ));
  159. };
  160. if spread {
  161. route_segments.push(RouteSegment::CatchAll(
  162. Ident::new(&ident, Span::call_site()),
  163. ty,
  164. ));
  165. if iterator.next().is_some() {
  166. return Err(syn::Error::new(
  167. route_span,
  168. "Catch-all route segments must be the last segment in a route. The route segments after the catch-all segment will never be matched.",
  169. ));
  170. } else {
  171. break;
  172. }
  173. } else {
  174. route_segments.push(RouteSegment::Dynamic(
  175. Ident::new(&ident, Span::call_site()),
  176. ty,
  177. ));
  178. }
  179. } else {
  180. route_segments.push(RouteSegment::Static(segment.to_string()));
  181. }
  182. }
  183. // check if the route has a query string
  184. let parsed_query = match query {
  185. Some(query) => {
  186. if let Some(query) = query.strip_prefix(":..") {
  187. let query_ident = Ident::new(query, Span::call_site());
  188. let field = fields.clone().find(|(name, _)| *name == &query_ident);
  189. let ty = if let Some((_, ty)) = field {
  190. ty.clone()
  191. } else {
  192. return Err(syn::Error::new(
  193. route_span,
  194. format!("Could not find a field with the name '{}'", query_ident),
  195. ));
  196. };
  197. Some(QuerySegment::Single(FullQuerySegment {
  198. ident: query_ident,
  199. ty,
  200. }))
  201. } else {
  202. let mut query_arguments = Vec::new();
  203. for segment in query.split('&') {
  204. if segment.is_empty() {
  205. return Err(syn::Error::new(
  206. route_span,
  207. "Query segments should be non-empty",
  208. ));
  209. }
  210. if let Some(query_argument) = segment.strip_prefix(':') {
  211. let query_ident = Ident::new(query_argument, Span::call_site());
  212. let field = fields.clone().find(|(name, _)| *name == &query_ident);
  213. let ty = if let Some((_, ty)) = field {
  214. ty.clone()
  215. } else {
  216. return Err(syn::Error::new(
  217. route_span,
  218. format!("Could not find a field with the name '{}'", query_ident),
  219. ));
  220. };
  221. query_arguments.push(QueryArgument {
  222. ident: query_ident,
  223. ty,
  224. });
  225. } else {
  226. return Err(syn::Error::new(
  227. route_span,
  228. "Query segments should be a : followed by the name of the query argument",
  229. ));
  230. }
  231. }
  232. Some(QuerySegment::Segments(query_arguments))
  233. }
  234. }
  235. None => None,
  236. };
  237. Ok((route_segments, parsed_query))
  238. }
  239. pub(crate) fn create_error_type(
  240. error_name: Ident,
  241. segments: &[RouteSegment],
  242. child_type: Option<&Type>,
  243. ) -> TokenStream2 {
  244. let mut error_variants = Vec::new();
  245. let mut display_match = Vec::new();
  246. for (i, segment) in segments.iter().enumerate() {
  247. let error_name = segment.error_name(i);
  248. match segment {
  249. RouteSegment::Static(index) => {
  250. error_variants.push(quote! { #error_name(String) });
  251. display_match.push(quote! { Self::#error_name(found) => write!(f, "Static segment '{}' did not match instead found '{}'", #index, found)? });
  252. }
  253. RouteSegment::Dynamic(ident, ty) => {
  254. let missing_error = segment.missing_error_name().unwrap();
  255. error_variants.push(
  256. quote! { #error_name(<#ty as dioxus_router::routable::FromRouteSegment>::Err) },
  257. );
  258. display_match.push(quote! { Self::#error_name(err) => write!(f, "Dynamic segment '({}:{})' did not match: {}", stringify!(#ident), stringify!(#ty), err)? });
  259. error_variants.push(quote! { #missing_error });
  260. display_match.push(quote! { Self::#missing_error => write!(f, "Dynamic segment '({}:{})' was missing", stringify!(#ident), stringify!(#ty))? });
  261. }
  262. RouteSegment::CatchAll(ident, ty) => {
  263. error_variants.push(quote! { #error_name(<#ty as dioxus_router::routable::FromRouteSegments>::Err) });
  264. display_match.push(quote! { Self::#error_name(err) => write!(f, "Catch-all segment '({}:{})' did not match: {}", stringify!(#ident), stringify!(#ty), err)? });
  265. }
  266. }
  267. }
  268. let child_type_variant = child_type
  269. .map(|child_type| {
  270. quote! { ChildRoute(<#child_type as std::str::FromStr>::Err) }
  271. })
  272. .into_iter();
  273. let child_type_error = child_type
  274. .map(|_| {
  275. quote! {
  276. Self::ChildRoute(error) => {
  277. write!(f, "{}", error)?
  278. }
  279. }
  280. })
  281. .into_iter();
  282. quote! {
  283. #[allow(non_camel_case_types)]
  284. #[allow(clippy::derive_partial_eq_without_eq)]
  285. #[derive(Debug, PartialEq)]
  286. pub enum #error_name {
  287. ExtraSegments(String),
  288. #(#child_type_variant,)*
  289. #(#error_variants,)*
  290. }
  291. impl std::fmt::Display for #error_name {
  292. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  293. match self {
  294. Self::ExtraSegments(segments) => {
  295. write!(f, "Found additional trailing segments: {}", segments)?
  296. },
  297. #(#child_type_error,)*
  298. #(#display_match,)*
  299. }
  300. Ok(())
  301. }
  302. }
  303. }
  304. }