route_tree.rs 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. use proc_macro2::TokenStream;
  2. use quote::quote;
  3. use slab::Slab;
  4. use syn::Ident;
  5. use crate::{
  6. nest::{Nest, NestId},
  7. redirect::Redirect,
  8. route::{Route, RouteType},
  9. segment::{static_segment_idx, RouteSegment},
  10. RouteEndpoint,
  11. };
  12. #[derive(Debug, Clone, Default)]
  13. pub(crate) struct ParseRouteTree<'a> {
  14. pub roots: Vec<usize>,
  15. entries: Slab<RouteTreeSegmentData<'a>>,
  16. }
  17. impl<'a> ParseRouteTree<'a> {
  18. pub fn get(&self, index: usize) -> Option<&RouteTreeSegmentData<'a>> {
  19. self.entries.get(index)
  20. }
  21. pub fn get_mut(&mut self, element: usize) -> Option<&mut RouteTreeSegmentData<'a>> {
  22. self.entries.get_mut(element)
  23. }
  24. fn sort_children(&mut self) {
  25. let mut old_roots = self.roots.clone();
  26. self.sort_ids(&mut old_roots);
  27. self.roots = old_roots;
  28. for id in self.roots.clone() {
  29. self.sort_children_of_id(id);
  30. }
  31. }
  32. fn sort_ids(&self, ids: &mut [usize]) {
  33. ids.sort_by_key(|&seg| {
  34. let seg = self.get(seg).unwrap();
  35. match seg {
  36. RouteTreeSegmentData::Static { .. } => 0,
  37. RouteTreeSegmentData::Nest { .. } => 1,
  38. RouteTreeSegmentData::Route(route) => {
  39. // Routes that end in a catch all segment should be checked last
  40. match route.segments.last() {
  41. Some(RouteSegment::CatchAll(..)) => 2,
  42. _ => 1,
  43. }
  44. }
  45. RouteTreeSegmentData::Redirect(redirect) => {
  46. // Routes that end in a catch all segment should be checked last
  47. match redirect.segments.last() {
  48. Some(RouteSegment::CatchAll(..)) => 2,
  49. _ => 1,
  50. }
  51. }
  52. }
  53. });
  54. }
  55. fn sort_children_of_id(&mut self, id: usize) {
  56. // Sort segments so that all static routes are checked before dynamic routes
  57. let mut children = self.children(id);
  58. self.sort_ids(&mut children);
  59. if let Some(old) = self.try_children_mut(id) {
  60. old.clone_from(&children)
  61. }
  62. for id in children {
  63. self.sort_children_of_id(id);
  64. }
  65. }
  66. fn children(&self, element: usize) -> Vec<usize> {
  67. let element = self.entries.get(element).unwrap();
  68. match element {
  69. RouteTreeSegmentData::Static { children, .. } => children.clone(),
  70. RouteTreeSegmentData::Nest { children, .. } => children.clone(),
  71. _ => Vec::new(),
  72. }
  73. }
  74. fn try_children_mut(&mut self, element: usize) -> Option<&mut Vec<usize>> {
  75. let element = self.entries.get_mut(element).unwrap();
  76. match element {
  77. RouteTreeSegmentData::Static { children, .. } => Some(children),
  78. RouteTreeSegmentData::Nest { children, .. } => Some(children),
  79. _ => None,
  80. }
  81. }
  82. fn children_mut(&mut self, element: usize) -> &mut Vec<usize> {
  83. self.try_children_mut(element)
  84. .expect("Cannot get children of non static or nest segment")
  85. }
  86. pub(crate) fn new(endpoints: &'a [RouteEndpoint], nests: &'a [Nest]) -> Self {
  87. let routes = endpoints
  88. .iter()
  89. .map(|endpoint| match endpoint {
  90. RouteEndpoint::Route(route) => PathIter::new_route(route, nests),
  91. RouteEndpoint::Redirect(redirect) => PathIter::new_redirect(redirect, nests),
  92. })
  93. .collect::<Vec<_>>();
  94. let mut myself = Self::default();
  95. myself.roots = myself.construct(routes);
  96. myself.sort_children();
  97. myself
  98. }
  99. pub fn construct(&mut self, routes: Vec<PathIter<'a>>) -> Vec<usize> {
  100. let mut segments = Vec::new();
  101. // Add all routes to the tree
  102. for mut route in routes {
  103. let mut current_route: Option<usize> = None;
  104. // First add all nests
  105. while let Some(nest) = route.next_nest() {
  106. let segments_iter = nest.segments.iter();
  107. // Add all static segments of the nest
  108. 'o: for (index, segment) in segments_iter.enumerate() {
  109. match segment {
  110. RouteSegment::Static(segment) => {
  111. // Check if the segment already exists
  112. {
  113. // Either look for the segment in the current route or in the static segments
  114. let segments = current_route
  115. .map(|id| self.children(id))
  116. .unwrap_or_else(|| segments.clone());
  117. for &seg_id in segments.iter() {
  118. let seg = self.get(seg_id).unwrap();
  119. if let RouteTreeSegmentData::Static { segment: s, .. } = seg {
  120. if *s == segment {
  121. // If it does, just update the current route
  122. current_route = Some(seg_id);
  123. continue 'o;
  124. }
  125. }
  126. }
  127. }
  128. let static_segment = RouteTreeSegmentData::Static {
  129. segment,
  130. children: Vec::new(),
  131. error_variant: StaticErrorVariant {
  132. variant_parse_error: nest.error_ident(),
  133. enum_variant: nest.error_variant(),
  134. },
  135. index,
  136. };
  137. // If it doesn't, add the segment to the current route
  138. let static_segment = self.entries.insert(static_segment);
  139. let current_children = current_route
  140. .map(|id| self.children_mut(id))
  141. .unwrap_or_else(|| &mut segments);
  142. current_children.push(static_segment);
  143. // Update the current route
  144. current_route = Some(static_segment);
  145. }
  146. // If there is a dynamic segment, stop adding static segments
  147. RouteSegment::Dynamic(..) => break,
  148. RouteSegment::CatchAll(..) => {
  149. todo!("Catch all segments are not allowed in nests")
  150. }
  151. }
  152. }
  153. // Add the nest to the current route
  154. let nest = RouteTreeSegmentData::Nest {
  155. nest,
  156. children: Vec::new(),
  157. };
  158. let nest = self.entries.insert(nest);
  159. let segments = match current_route.and_then(|id| self.get_mut(id)) {
  160. Some(RouteTreeSegmentData::Static { children, .. }) => children,
  161. Some(RouteTreeSegmentData::Nest { children, .. }) => children,
  162. Some(r) => {
  163. unreachable!("{current_route:?}\n{r:?} is not a static or nest segment",)
  164. }
  165. None => &mut segments,
  166. };
  167. segments.push(nest);
  168. // Update the current route
  169. current_route = Some(nest);
  170. }
  171. match route.next_static_segment() {
  172. // If there is a static segment, check if it already exists in the tree
  173. Some((i, segment)) => {
  174. let current_children = current_route
  175. .map(|id| self.children(id))
  176. .unwrap_or_else(|| segments.clone());
  177. let found = current_children.iter().find_map(|&id| {
  178. let seg = self.get(id).unwrap();
  179. match seg {
  180. RouteTreeSegmentData::Static { segment: s, .. } => {
  181. (s == &segment).then_some(id)
  182. }
  183. _ => None,
  184. }
  185. });
  186. match found {
  187. Some(id) => {
  188. // If it exists, add the route to the children of the segment
  189. let new_children = self.construct(vec![route]);
  190. self.children_mut(id).extend(new_children);
  191. }
  192. None => {
  193. // If it doesn't exist, add the route as a new segment
  194. let data = RouteTreeSegmentData::Static {
  195. segment,
  196. error_variant: route.error_variant(),
  197. children: self.construct(vec![route]),
  198. index: i,
  199. };
  200. let id = self.entries.insert(data);
  201. let current_children_mut = current_route
  202. .map(|id| self.children_mut(id))
  203. .unwrap_or_else(|| &mut segments);
  204. current_children_mut.push(id);
  205. }
  206. }
  207. }
  208. // If there is no static segment, add the route to the current_route
  209. None => {
  210. let id = self.entries.insert(route.final_segment);
  211. let current_children_mut = current_route
  212. .map(|id| self.children_mut(id))
  213. .unwrap_or_else(|| &mut segments);
  214. current_children_mut.push(id);
  215. }
  216. }
  217. }
  218. segments
  219. }
  220. }
  221. #[derive(Debug, Clone)]
  222. pub struct StaticErrorVariant {
  223. variant_parse_error: Ident,
  224. enum_variant: Ident,
  225. }
  226. // First deduplicate the routes by the static part of the route
  227. #[derive(Debug, Clone)]
  228. pub(crate) enum RouteTreeSegmentData<'a> {
  229. Static {
  230. segment: &'a str,
  231. error_variant: StaticErrorVariant,
  232. index: usize,
  233. children: Vec<usize>,
  234. },
  235. Nest {
  236. nest: &'a Nest,
  237. children: Vec<usize>,
  238. },
  239. Route(&'a Route),
  240. Redirect(&'a Redirect),
  241. }
  242. impl RouteTreeSegmentData<'_> {
  243. pub fn to_tokens(
  244. &self,
  245. nests: &[Nest],
  246. tree: &ParseRouteTree,
  247. enum_name: syn::Ident,
  248. error_enum_name: syn::Ident,
  249. ) -> TokenStream {
  250. match self {
  251. RouteTreeSegmentData::Static {
  252. segment,
  253. children,
  254. index,
  255. error_variant:
  256. StaticErrorVariant {
  257. variant_parse_error,
  258. enum_variant,
  259. },
  260. } => {
  261. let children = children.iter().map(|child| {
  262. let child = tree.get(*child).unwrap();
  263. child.to_tokens(nests, tree, enum_name.clone(), error_enum_name.clone())
  264. });
  265. if segment.is_empty() {
  266. return quote! {
  267. {
  268. #(#children)*
  269. }
  270. };
  271. }
  272. let error_ident = static_segment_idx(*index);
  273. quote! {
  274. {
  275. let mut segments = segments.clone();
  276. let segment = segments.next();
  277. if let Some(segment) = segment.as_deref() {
  278. if #segment == segment {
  279. #(#children)*
  280. } else {
  281. errors.push(#error_enum_name::#enum_variant(#variant_parse_error::#error_ident(segment.to_string())))
  282. }
  283. }
  284. }
  285. }
  286. }
  287. RouteTreeSegmentData::Route(route) => {
  288. // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
  289. let variant_parse_error = route.error_ident();
  290. let enum_variant = &route.route_name;
  291. let route_segments = route
  292. .segments
  293. .iter()
  294. .enumerate()
  295. .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)))
  296. .filter(|(i, _)| {
  297. // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
  298. !is_trailing_static_segment(&route.segments, *i)
  299. });
  300. let construct_variant = route.construct(nests, enum_name);
  301. let parse_query = route.parse_query();
  302. let parse_hash = route.parse_hash();
  303. let insure_not_trailing = match route.ty {
  304. RouteType::Leaf { .. } => route
  305. .segments
  306. .last()
  307. .map(|seg| !matches!(seg, RouteSegment::CatchAll(_, _)))
  308. .unwrap_or(true),
  309. RouteType::Child(_) => false,
  310. };
  311. let print_route_segment = print_route_segment(
  312. route_segments.peekable(),
  313. return_constructed(
  314. insure_not_trailing,
  315. construct_variant,
  316. &error_enum_name,
  317. enum_variant,
  318. &variant_parse_error,
  319. parse_query,
  320. parse_hash,
  321. ),
  322. &error_enum_name,
  323. enum_variant,
  324. &variant_parse_error,
  325. );
  326. match &route.ty {
  327. RouteType::Child(child) => {
  328. let ty = &child.ty;
  329. let child_name = &child.ident;
  330. quote! {
  331. let mut trailing = String::from("/");
  332. for seg in segments.clone() {
  333. trailing += &*seg;
  334. trailing += "/";
  335. }
  336. match #ty::from_str(&trailing).map_err(|err| #error_enum_name::#enum_variant(#variant_parse_error::ChildRoute(err))) {
  337. Ok(#child_name) => {
  338. #print_route_segment
  339. }
  340. Err(err) => {
  341. errors.push(err);
  342. }
  343. }
  344. }
  345. }
  346. RouteType::Leaf { .. } => print_route_segment,
  347. }
  348. }
  349. Self::Nest { nest, children } => {
  350. // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
  351. let variant_parse_error: Ident = nest.error_ident();
  352. let enum_variant = nest.error_variant();
  353. let route_segments = nest
  354. .segments
  355. .iter()
  356. .enumerate()
  357. .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)));
  358. let parse_children = children
  359. .iter()
  360. .map(|child| {
  361. let child = tree.get(*child).unwrap();
  362. child.to_tokens(nests, tree, enum_name.clone(), error_enum_name.clone())
  363. })
  364. .collect();
  365. print_route_segment(
  366. route_segments.peekable(),
  367. parse_children,
  368. &error_enum_name,
  369. &enum_variant,
  370. &variant_parse_error,
  371. )
  372. }
  373. Self::Redirect(redirect) => {
  374. // At this point, we have matched all static segments, so we can just check if the remaining segments match the route
  375. let variant_parse_error = redirect.error_ident();
  376. let enum_variant = &redirect.error_variant();
  377. let route_segments = redirect
  378. .segments
  379. .iter()
  380. .enumerate()
  381. .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)));
  382. let parse_query = redirect.parse_query();
  383. let parse_hash = redirect.parse_hash();
  384. let insure_not_trailing = redirect
  385. .segments
  386. .last()
  387. .map(|seg| !matches!(seg, RouteSegment::CatchAll(_, _)))
  388. .unwrap_or(true);
  389. let redirect_function = &redirect.function;
  390. let args = redirect_function.inputs.iter().map(|pat| match pat {
  391. syn::Pat::Type(ident) => {
  392. let name = &ident.pat;
  393. quote! {#name}
  394. }
  395. _ => panic!("Expected closure argument to be a typed pattern"),
  396. });
  397. let return_redirect = quote! {
  398. (#redirect_function)(#(#args,)*)
  399. };
  400. print_route_segment(
  401. route_segments.peekable(),
  402. return_constructed(
  403. insure_not_trailing,
  404. return_redirect,
  405. &error_enum_name,
  406. enum_variant,
  407. &variant_parse_error,
  408. parse_query,
  409. parse_hash,
  410. ),
  411. &error_enum_name,
  412. enum_variant,
  413. &variant_parse_error,
  414. )
  415. }
  416. }
  417. }
  418. }
  419. fn print_route_segment<'a, I: Iterator<Item = (usize, &'a RouteSegment)>>(
  420. mut s: std::iter::Peekable<I>,
  421. success_tokens: TokenStream,
  422. error_enum_name: &Ident,
  423. enum_variant: &Ident,
  424. variant_parse_error: &Ident,
  425. ) -> TokenStream {
  426. if let Some((i, route)) = s.next() {
  427. let children = print_route_segment(
  428. s,
  429. success_tokens,
  430. error_enum_name,
  431. enum_variant,
  432. variant_parse_error,
  433. );
  434. route.try_parse(
  435. i,
  436. error_enum_name,
  437. enum_variant,
  438. variant_parse_error,
  439. children,
  440. )
  441. } else {
  442. quote! {
  443. #success_tokens
  444. }
  445. }
  446. }
  447. fn return_constructed(
  448. insure_not_trailing: bool,
  449. construct_variant: TokenStream,
  450. error_enum_name: &Ident,
  451. enum_variant: &Ident,
  452. variant_parse_error: &Ident,
  453. parse_query: TokenStream,
  454. parse_hash: TokenStream,
  455. ) -> TokenStream {
  456. if insure_not_trailing {
  457. quote! {
  458. let remaining_segments = segments.clone();
  459. let mut segments_clone = segments.clone();
  460. let next_segment = segments_clone.next();
  461. // This is the last segment, return the parsed route
  462. if next_segment.is_none() {
  463. #parse_query
  464. #parse_hash
  465. return Ok(#construct_variant);
  466. } else {
  467. let mut trailing = String::new();
  468. for seg in remaining_segments {
  469. trailing += &*seg;
  470. trailing += "/";
  471. }
  472. trailing.pop();
  473. errors.push(#error_enum_name::#enum_variant(#variant_parse_error::ExtraSegments(trailing)))
  474. }
  475. }
  476. } else {
  477. quote! {
  478. #parse_query
  479. #parse_hash
  480. return Ok(#construct_variant);
  481. }
  482. }
  483. }
  484. pub struct PathIter<'a> {
  485. final_segment: RouteTreeSegmentData<'a>,
  486. active_nests: &'a [NestId],
  487. all_nests: &'a [Nest],
  488. segments: &'a [RouteSegment],
  489. error_ident: Ident,
  490. error_variant: Ident,
  491. nest_index: usize,
  492. static_segment_index: usize,
  493. }
  494. impl<'a> PathIter<'a> {
  495. fn new_route(route: &'a Route, nests: &'a [Nest]) -> Self {
  496. Self {
  497. final_segment: RouteTreeSegmentData::Route(route),
  498. active_nests: &*route.nests,
  499. segments: &*route.segments,
  500. error_ident: route.error_ident(),
  501. error_variant: route.route_name.clone(),
  502. all_nests: nests,
  503. nest_index: 0,
  504. static_segment_index: 0,
  505. }
  506. }
  507. fn new_redirect(redirect: &'a Redirect, nests: &'a [Nest]) -> Self {
  508. Self {
  509. final_segment: RouteTreeSegmentData::Redirect(redirect),
  510. active_nests: &*redirect.nests,
  511. segments: &*redirect.segments,
  512. error_ident: redirect.error_ident(),
  513. error_variant: redirect.error_variant(),
  514. all_nests: nests,
  515. nest_index: 0,
  516. static_segment_index: 0,
  517. }
  518. }
  519. fn next_nest(&mut self) -> Option<&'a Nest> {
  520. let idx = self.nest_index;
  521. let nest_index = self.active_nests.get(idx)?;
  522. let nest = &self.all_nests[nest_index.0];
  523. self.nest_index += 1;
  524. Some(nest)
  525. }
  526. fn next_static_segment(&mut self) -> Option<(usize, &'a str)> {
  527. let idx = self.static_segment_index;
  528. let segment = self.segments.get(idx)?;
  529. // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
  530. if is_trailing_static_segment(self.segments, idx) {
  531. return None;
  532. }
  533. match segment {
  534. RouteSegment::Static(segment) => {
  535. self.static_segment_index += 1;
  536. Some((idx, segment))
  537. }
  538. _ => None,
  539. }
  540. }
  541. fn error_variant(&self) -> StaticErrorVariant {
  542. StaticErrorVariant {
  543. variant_parse_error: self.error_ident.clone(),
  544. enum_variant: self.error_variant.clone(),
  545. }
  546. }
  547. }
  548. // If this is the last segment and it is an empty trailing segment, skip parsing it. The parsing code handles parsing /path/ and /path
  549. pub(crate) fn is_trailing_static_segment(segments: &[RouteSegment], index: usize) -> bool {
  550. // This can only be a trailing segment if we have more than one segment and this is the last segment
  551. matches!(segments.get(index), Some(RouteSegment::Static(segment)) if segment.is_empty() && index == segments.len() - 1 && segments.len() > 1)
  552. }