瀏覽代碼

allow child routers to have fields

Evan Almloff 2 年之前
父節點
當前提交
914767892c
共有 2 個文件被更改,包括 79 次插入48 次删除
  1. 1 1
      packages/router-macro/src/lib.rs
  2. 78 47
      packages/router-macro/src/route.rs

+ 1 - 1
packages/router-macro/src/lib.rs

@@ -293,7 +293,7 @@ impl RouteEnum {
             let mut segment = SiteMapSegment::new(&route.segments);
             if let RouteType::Child(child) = &route.ty {
                 let new_segment = SiteMapSegment {
-                    segment_type: SegmentType::Child(child.clone()),
+                    segment_type: SegmentType::Child(child.ty.clone()),
                     children: Vec::new(),
                 };
                 match &mut segment {

+ 78 - 47
packages/router-macro/src/route.rs

@@ -2,6 +2,7 @@ use quote::{format_ident, quote};
 use syn::parse::Parse;
 use syn::parse::ParseStream;
 use syn::parse_quote;
+use syn::Field;
 use syn::Path;
 use syn::Type;
 use syn::{Ident, LitStr};
@@ -110,19 +111,30 @@ impl Route {
                     let args = route_attr.parse_args::<ChildArgs>()?;
                     route = args.route.value();
                     match &variant.fields {
-                        syn::Fields::Unnamed(fields) => {
-                            if fields.unnamed.len() != 1 {
-                                return Err(syn::Error::new_spanned(
-                                    variant.clone(),
-                                    "Routable variants with a #[child(...)] attribute must have exactly one field",
-                                ));
+                        syn::Fields::Named(fields) => {
+                            // find either a field with #[child] or a field named "child"
+                            let child_field = fields.named.iter().find(|f| {
+                                f.attrs
+                                    .iter()
+                                    .any(|attr| attr.path.is_ident("child"))
+                                    || f.ident.as_ref().unwrap().to_string() == "child"
+                            });
+                            match child_field{
+                                Some(child) => {
+                                    ty = RouteType::Child(child.clone());
+                                }
+                                None => {
+                                    return Err(syn::Error::new_spanned(
+                                        variant.clone(),
+                                        "Routable variants with a #[child(...)] attribute must have a field named \"child\" or a field with a #[child] attribute",
+                                    ));
+                                }
                             }
-                            ty = RouteType::Child(fields.unnamed[0].ty.clone());
                         }
                         _ => {
                             return Err(syn::Error::new_spanned(
                                 variant.clone(),
-                                "Routable variants with a #[child(...)] attribute must have exactly one field",
+                                "Routable variants with a #[child(...)] attribute must have named fields",
                             ))
                         }
                     }
@@ -139,7 +151,14 @@ impl Route {
             syn::Fields::Named(fields) => fields
                 .named
                 .iter()
-                .map(|f| (f.ident.clone().unwrap(), f.ty.clone()))
+                .filter_map(|f| {
+                    if let RouteType::Child(child) = &ty {
+                        if f.ident == child.ident {
+                            return None;
+                        }
+                    }
+                    Some((f.ident.clone().unwrap(), f.ty.clone()))
+                })
                 .collect(),
             _ => Vec::new(),
         };
@@ -172,13 +191,14 @@ impl Route {
         let write_query = self.query.as_ref().map(|q| q.write());
 
         match &self.ty {
-            RouteType::Child(_) => {
+            RouteType::Child(field) => {
+                let child = field.ident.as_ref().unwrap();
                 quote! {
-                    Self::#name(child) => {
+                    Self::#name { #(#dynamic_segments,)* #child } => {
                         use std::fmt::Display;
                         #(#write_layouts)*
                         #(#write_segments)*
-                        child.fmt(f);
+                        #child.fmt(f);
                     }
                 }
             }
@@ -204,10 +224,15 @@ impl Route {
         for (idx, layout_id) in self.layouts.iter().copied().enumerate() {
             let render_layout = layouts[layout_id.0].routable_match(nests);
             let dynamic_segments = self.dynamic_segments();
+            let mut field_name = None;
+            if let RouteType::Child(field) = &self.ty {
+                field_name = field.ident.as_ref();
+            }
+            let field_name = field_name.map(|f| quote!(#f,));
             // This is a layout
             tokens.extend(quote! {
                 #[allow(unused)]
-                (#idx, Self::#name { #(#dynamic_segments,)* }) => {
+                (#idx, Self::#name { #(#dynamic_segments,)* #field_name .. }) => {
                     #render_layout
                 }
             });
@@ -216,11 +241,12 @@ impl Route {
         // Then match the route
         let last_index = self.layouts.len();
         tokens.extend(match &self.ty {
-            RouteType::Child(_) => {
+            RouteType::Child(field) => {
+                let field_name = field.ident.as_ref().unwrap();
                 quote! {
                     #[allow(unused)]
-                    (#last_index.., Self::#name(child_route)) => {
-                        child_route.render(cx, level - #last_index)
+                    (#last_index.., Self::#name { #field_name, .. }) => {
+                        #field_name.render(cx, level - #last_index)
                     }
                 }
             }
@@ -250,9 +276,39 @@ impl Route {
     }
 
     pub fn construct(&self, nests: &[Nest], enum_name: Ident) -> TokenStream2 {
+        let segments = self.fields.iter().map(|(name, _)| {
+            let mut from_route = false;
+
+            for id in &self.nests {
+                let nest = &nests[id.0];
+                if nest.dynamic_segments_names().any(|i| &i == name) {
+                    from_route = true
+                }
+            }
+            for segment in &self.segments {
+                if let RouteSegment::Dynamic(other, ..) = segment {
+                    if other == name {
+                        from_route = true
+                    }
+                }
+            }
+            if let Some(query) = &self.query {
+                if &query.ident == name {
+                    from_route = true
+                }
+            }
+
+            if from_route {
+                quote! {#name}
+            } else {
+                quote! {#name: Default::default()}
+            }
+        });
         match &self.ty {
-            RouteType::Child(ty) => {
+            RouteType::Child(field) => {
                 let name = &self.route_name;
+                let child_name = field.ident.as_ref().unwrap();
+                let ty = &field.ty;
 
                 quote! {
                     {
@@ -262,39 +318,14 @@ impl Route {
                             trailing += "/";
                         }
                         trailing.pop();
-                        #enum_name::#name(#ty::from_str(&trailing).unwrap())
+                        #enum_name::#name {
+                            #child_name: #ty::from_str(&trailing).unwrap(),
+                            #(#segments,)*
+                        }
                     }
                 }
             }
             RouteType::Leaf { .. } => {
-                let segments = self.fields.iter().map(|(name, _)| {
-                    let mut from_route = false;
-
-                    for id in &self.nests {
-                        let nest = &nests[id.0];
-                        if nest.dynamic_segments_names().any(|i| &i == name) {
-                            from_route = true
-                        }
-                    }
-                    for segment in &self.segments {
-                        if let RouteSegment::Dynamic(other, ..) = segment {
-                            if other == name {
-                                from_route = true
-                            }
-                        }
-                    }
-                    if let Some(query) = &self.query {
-                        if &query.ident == name {
-                            from_route = true
-                        }
-                    }
-
-                    if from_route {
-                        quote! {#name}
-                    } else {
-                        quote! {#name: Default::default()}
-                    }
-                });
                 let name = &self.route_name;
 
                 quote! {
@@ -326,6 +357,6 @@ impl Route {
 
 #[derive(Debug)]
 pub(crate) enum RouteType {
-    Child(Type),
+    Child(Field),
     Leaf { component: Path, props: Path },
 }