Pārlūkot izejas kodu

Parse trailing route slash (#2896)

* Parse trailing route slash

* Fix typo
Evan Almloff 9 mēneši atpakaļ
vecāks
revīzija
7bb53fe835

+ 1 - 1
packages/fullstack/examples/router/src/main.rs

@@ -24,7 +24,7 @@ enum Route {
     #[route("/")]
     Home {},
 
-    #[route("/blog/:id")]
+    #[route("/blog/:id/")]
     Blog { id: i32 },
 }
 

+ 6 - 4
packages/router-macro/src/lib.rs

@@ -14,7 +14,7 @@ use syn::{parse::ParseStream, parse_macro_input, Ident, Token, Type};
 
 use proc_macro2::TokenStream as TokenStream2;
 
-use crate::{layout::LayoutId, route_tree::RouteTree};
+use crate::{layout::LayoutId, route_tree::ParseRouteTree};
 
 mod hash;
 mod layout;
@@ -574,7 +574,7 @@ impl RouteEnum {
     }
 
     fn parse_impl(&self) -> TokenStream2 {
-        let tree = RouteTree::new(&self.endpoints, &self.nests);
+        let tree = ParseRouteTree::new(&self.endpoints, &self.nests);
         let name = &self.name;
 
         let error_name = format_ident!("{}MatchError", self.name);
@@ -599,14 +599,16 @@ impl RouteEnum {
                     let route = s;
                     let (route, hash) = route.split_once('#').unwrap_or((route, ""));
                     let (route, query) = route.split_once('?').unwrap_or((route, ""));
+                    // Remove any trailing slashes. We parse /route/ and /route in the same way
+                    // Note: we don't use trim because it includes more code
+                    let route = route.strip_suffix('/').unwrap_or(route);
                     let query = dioxus_router::exports::urlencoding::decode(query).unwrap_or(query.into());
                     let hash = dioxus_router::exports::urlencoding::decode(hash).unwrap_or(hash.into());
                     let mut segments = route.split('/').map(|s| dioxus_router::exports::urlencoding::decode(s).unwrap_or(s.into()));
                     // skip the first empty segment
                     if s.starts_with('/') {
                         let _ = segments.next();
-                    }
-                    else {
+                    } else {
                         // if this route does not start with a slash, it is not a valid route
                         return Err(dioxus_router::routable::RouteParseError {
                             attempted_routes: Vec::new(),

+ 31 - 25
packages/router-macro/src/route_tree.rs

@@ -12,12 +12,12 @@ use crate::{
 };
 
 #[derive(Debug, Clone, Default)]
-pub(crate) struct RouteTree<'a> {
+pub(crate) struct ParseRouteTree<'a> {
     pub roots: Vec<usize>,
     entries: Slab<RouteTreeSegmentData<'a>>,
 }
 
-impl<'a> RouteTree<'a> {
+impl<'a> ParseRouteTree<'a> {
     pub fn get(&self, index: usize) -> Option<&RouteTreeSegmentData<'a>> {
         self.entries.get(index)
     }
@@ -278,7 +278,7 @@ impl<'a> RouteTreeSegmentData<'a> {
     pub fn to_tokens(
         &self,
         nests: &[Nest],
-        tree: &RouteTree,
+        tree: &ParseRouteTree,
         enum_name: syn::Ident,
         error_enum_name: syn::Ident,
     ) -> TokenStream {
@@ -315,8 +315,7 @@ impl<'a> RouteTreeSegmentData<'a> {
                         if let Some(segment) = segment.as_deref() {
                             if #segment == segment {
                                 #(#children)*
-                            }
-                            else {
+                            } else {
                                 errors.push(#error_enum_name::#enum_variant(#variant_parse_error::#error_ident(segment.to_string())))
                             }
                         }
@@ -332,7 +331,11 @@ impl<'a> RouteTreeSegmentData<'a> {
                     .segments
                     .iter()
                     .enumerate()
-                    .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)));
+                    .skip_while(|(_, seg)| matches!(seg, RouteSegment::Static(_)))
+                    .filter(|(i, _)| {
+                        // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
+                        !is_trailing_static_segment(&route.segments, *i)
+                    });
 
                 let construct_variant = route.construct(nests, enum_name);
                 let parse_query = route.parse_query();
@@ -374,7 +377,6 @@ impl<'a> RouteTreeSegmentData<'a> {
                                 trailing += &*seg;
                                 trailing += "/";
                             }
-                            trailing.pop();
                             match #ty::from_str(&trailing).map_err(|err| #error_enum_name::#enum_variant(#variant_parse_error::ChildRoute(err))) {
                                 Ok(#child_name) => {
                                     #print_route_segment
@@ -511,25 +513,19 @@ fn return_constructed(
             let remaining_segments = segments.clone();
             let mut segments_clone = segments.clone();
             let next_segment = segments_clone.next();
-            let next_segment = next_segment.as_deref();
-            let segment_after_next = segments_clone.next();
-            let segment_after_next = segment_after_next.as_deref();
-            match (next_segment, segment_after_next) {
-                // This is the last segment, return the parsed route
-                (None, _) | (Some(""), None) => {
-                    #parse_query
-                    #parse_hash
-                    return Ok(#construct_variant);
-                }
-                _ => {
-                    let mut trailing = String::new();
-                    for seg in remaining_segments {
-                        trailing += &*seg;
-                        trailing += "/";
-                    }
-                    trailing.pop();
-                    errors.push(#error_enum_name::#enum_variant(#variant_parse_error::ExtraSegments(trailing)))
+            // This is the last segment, return the parsed route
+            if next_segment.is_none() {
+                #parse_query
+                #parse_hash
+                return Ok(#construct_variant);
+            } else {
+                let mut trailing = String::new();
+                for seg in remaining_segments {
+                    trailing += &*seg;
+                    trailing += "/";
                 }
+                trailing.pop();
+                errors.push(#error_enum_name::#enum_variant(#variant_parse_error::ExtraSegments(trailing)))
             }
         }
     } else {
@@ -590,6 +586,10 @@ impl<'a> PathIter<'a> {
     fn next_static_segment(&mut self) -> Option<(usize, &'a str)> {
         let idx = self.static_segment_index;
         let segment = self.segments.get(idx)?;
+        // Don't add any trailing static segments. We strip them during parsing so that routes can accept either `/route/` and `/route`
+        if is_trailing_static_segment(self.segments, idx) {
+            return None;
+        }
         match segment {
             RouteSegment::Static(segment) => {
                 self.static_segment_index += 1;
@@ -606,3 +606,9 @@ impl<'a> PathIter<'a> {
         }
     }
 }
+
+// If this is the last segment and it is an empty trailing segment, skip parsing it. The parsing code handles parsing /path/ and /path
+pub(crate) fn is_trailing_static_segment(segments: &[RouteSegment], index: usize) -> bool {
+    // This can only be a trailing segment if we have more than one segment and this is the last segment
+    matches!(segments.get(index), Some(RouteSegment::Static(segment)) if segment.is_empty() && index == segments.len() - 1 && segments.len() > 1)
+}

+ 3 - 11
packages/router-macro/src/segment.rs

@@ -65,18 +65,10 @@ impl RouteSegment {
                         let mut segments = segments.clone();
                         let segment = segments.next();
                         let segment = segment.as_deref();
-                        let parsed = if let Some(#segment) = segment {
-                            Ok(())
+                        if let Some(#segment) = segment {
+                            #parse_children
                         } else {
-                            Err(#error_enum_name::#error_enum_variant(#inner_parse_enum::#error_name(segment.map(|s|s.to_string()).unwrap_or_default())))
-                        };
-                        match parsed {
-                            Ok(_) => {
-                                #parse_children
-                            }
-                            Err(err) => {
-                                errors.push(err);
-                            }
+                            errors.push(#error_enum_name::#error_enum_variant(#inner_parse_enum::#error_name(segment.map(|s|s.to_string()).unwrap_or_default())));
                         }
                     }
                 }

+ 77 - 0
packages/router/tests/parsing.rs

@@ -0,0 +1,77 @@
+use dioxus::prelude::*;
+use std::str::FromStr;
+
+#[component]
+fn Root() -> Element {
+    todo!()
+}
+
+#[component]
+fn Test() -> Element {
+    todo!()
+}
+
+#[component]
+fn Dynamic(id: usize) -> Element {
+    todo!()
+}
+
+// Make sure trailing '/'s work correctly
+#[test]
+fn trailing_slashes_parse() {
+    #[derive(Routable, Clone, Copy, PartialEq, Debug)]
+    enum Route {
+        #[route("/")]
+        Root {},
+        #[route("/test/")]
+        Test {},
+        #[route("/:id/test/")]
+        Dynamic { id: usize },
+    }
+
+    assert_eq!(Route::from_str("/").unwrap(), Route::Root {});
+    assert_eq!(Route::from_str("/test/").unwrap(), Route::Test {});
+    assert_eq!(Route::from_str("/test").unwrap(), Route::Test {});
+    assert_eq!(
+        Route::from_str("/123/test/").unwrap(),
+        Route::Dynamic { id: 123 }
+    );
+    assert_eq!(
+        Route::from_str("/123/test").unwrap(),
+        Route::Dynamic { id: 123 }
+    );
+}
+
+#[test]
+fn without_trailing_slashes_parse() {
+    #[derive(Routable, Clone, Copy, PartialEq, Debug)]
+    enum RouteWithoutTrailingSlash {
+        #[route("/")]
+        Root {},
+        #[route("/test")]
+        Test {},
+        #[route("/:id/test")]
+        Dynamic { id: usize },
+    }
+
+    assert_eq!(
+        RouteWithoutTrailingSlash::from_str("/").unwrap(),
+        RouteWithoutTrailingSlash::Root {}
+    );
+    assert_eq!(
+        RouteWithoutTrailingSlash::from_str("/test/").unwrap(),
+        RouteWithoutTrailingSlash::Test {}
+    );
+    assert_eq!(
+        RouteWithoutTrailingSlash::from_str("/test").unwrap(),
+        RouteWithoutTrailingSlash::Test {}
+    );
+    assert_eq!(
+        RouteWithoutTrailingSlash::from_str("/123/test/").unwrap(),
+        RouteWithoutTrailingSlash::Dynamic { id: 123 }
+    );
+    assert_eq!(
+        RouteWithoutTrailingSlash::from_str("/123/test").unwrap(),
+        RouteWithoutTrailingSlash::Dynamic { id: 123 }
+    );
+}