Skip to content

Commit 1c2fe2a

Browse files
authored
Rework FromStr derive (#468, #467)
Required for #467, #469 This PR fully reworks `FromStr` derive implementation to avoid using old `utils` machinery. Also, adds `compile_fail` tests for `FromStr`.
1 parent 211165a commit 1c2fe2a

9 files changed

+403
-145
lines changed

impl/src/from_str.rs

Lines changed: 171 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,193 @@
1-
use crate::utils::{DeriveType, HashMap};
2-
use crate::utils::{SingleFieldData, State};
3-
use proc_macro2::TokenStream;
4-
use quote::quote;
5-
use syn::{parse::Result, DeriveInput};
1+
//! Implementation of a [`FromStr`] derive macro.
62
7-
/// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr`
8-
pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
9-
let state = State::new(input, trait_name, trait_name.to_lowercase())?;
3+
use std::collections::HashMap;
4+
#[cfg(doc)]
5+
use std::str::FromStr;
106

11-
if state.derive_type == DeriveType::Enum {
12-
Ok(enum_from(input, state, trait_name))
13-
} else {
14-
Ok(struct_from(&state, trait_name))
7+
use proc_macro2::TokenStream;
8+
use quote::{quote, ToTokens};
9+
use syn::{parse_quote, spanned::Spanned as _};
10+
11+
/// Expands a [`FromStr`] derive macro.
12+
pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStream> {
13+
match &input.data {
14+
syn::Data::Struct(_) => {
15+
Ok(ForwardExpansion::try_from(input)?.into_token_stream())
16+
}
17+
syn::Data::Enum(_) => {
18+
Ok(EnumFlatExpansion::try_from(input)?.into_token_stream())
19+
}
20+
syn::Data::Union(data) => Err(syn::Error::new(
21+
data.union_token.span(),
22+
"`FromStr` cannot be derived for unions",
23+
)),
1524
}
1625
}
1726

18-
pub fn struct_from(state: &State, trait_name: &'static str) -> TokenStream {
19-
// We cannot set defaults for fields, once we do we can remove this check
20-
if state.fields.len() != 1 || state.enabled_fields().len() != 1 {
21-
panic_one_field(trait_name);
22-
}
27+
/// Expansion of a macro for generating a forwarding [`FromStr`] implementation of a struct.
28+
struct ForwardExpansion<'i> {
29+
/// [`syn::Ident`] and [`syn::Generics`] of the struct.
30+
///
31+
/// [`syn::Ident`]: struct@syn::Ident
32+
self_ty: (&'i syn::Ident, &'i syn::Generics),
2333

24-
let single_field_data = state.assert_single_enabled_field();
25-
let SingleFieldData {
26-
input_type,
27-
field_type,
28-
trait_path,
29-
casted_trait,
30-
impl_generics,
31-
ty_generics,
32-
where_clause,
33-
..
34-
} = single_field_data.clone();
35-
36-
let initializers = [quote! { #casted_trait::from_str(src)? }];
37-
let body = single_field_data.initializer(&initializers);
38-
let error = quote! { <#field_type as #trait_path>::Err };
39-
40-
quote! {
41-
#[automatically_derived]
42-
impl #impl_generics #trait_path for #input_type #ty_generics #where_clause {
43-
type Err = #error;
44-
45-
#[inline]
46-
fn from_str(src: &str) -> derive_more::core::result::Result<Self, #error> {
47-
derive_more::core::result::Result::Ok(#body)
48-
}
49-
}
50-
}
34+
/// [`syn::Field`] representing the wrapped type to forward implementation on.
35+
inner: &'i syn::Field,
5136
}
5237

53-
fn enum_from(
54-
input: &DeriveInput,
55-
state: State,
56-
trait_name: &'static str,
57-
) -> TokenStream {
58-
let mut variants_caseinsensitive = HashMap::default();
59-
for variant_state in state.enabled_variant_data().variant_states {
60-
let variant = variant_state.variant.unwrap();
61-
if !variant.fields.is_empty() {
62-
panic!("Only enums with no fields can derive({trait_name})")
38+
impl<'i> TryFrom<&'i syn::DeriveInput> for ForwardExpansion<'i> {
39+
type Error = syn::Error;
40+
41+
fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> {
42+
let syn::Data::Struct(data) = &input.data else {
43+
return Err(syn::Error::new(
44+
input.span(),
45+
"expected a struct for forward `FromStr` derive",
46+
));
47+
};
48+
49+
// TODO: Unite these two conditions via `&&` once MSRV is bumped to 1.88 or above.
50+
if data.fields.len() != 1 {
51+
return Err(syn::Error::new(
52+
data.fields.span(),
53+
"only structs with single field can derive `FromStr`",
54+
));
6355
}
64-
65-
variants_caseinsensitive
66-
.entry(variant.ident.to_string().to_lowercase())
67-
.or_insert_with(Vec::new)
68-
.push(variant.ident.clone());
56+
let Some(inner) = data.fields.iter().next() else {
57+
return Err(syn::Error::new(
58+
data.fields.span(),
59+
"only structs with single field can derive `FromStr`",
60+
));
61+
};
62+
63+
Ok(Self {
64+
self_ty: (&input.ident, &input.generics),
65+
inner,
66+
})
6967
}
68+
}
7069

71-
let input_type = &input.ident;
72-
let input_type_name = input_type.to_string();
73-
74-
let mut cases = vec![];
70+
impl ToTokens for ForwardExpansion<'_> {
71+
/// Expands a forwarding [`FromStr`] implementations for a struct.
72+
fn to_tokens(&self, tokens: &mut TokenStream) {
73+
let inner_ty = &self.inner.ty;
74+
let ty = self.self_ty.0;
75+
76+
let mut generics = self.self_ty.1.clone();
77+
if !generics.params.is_empty() {
78+
generics.make_where_clause().predicates.push(parse_quote! {
79+
#inner_ty: derive_more::core::str::FromStr
80+
});
81+
}
82+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
7583

76-
// if a case insensitive match is unique match do that
77-
// otherwise do a case sensitive match
78-
for (ref canonical, ref variants) in variants_caseinsensitive {
79-
if variants.len() == 1 {
80-
let variant = &variants[0];
81-
cases.push(quote! {
82-
#canonical => #input_type::#variant,
83-
})
84+
let constructor = if let Some(name) = &self.inner.ident {
85+
quote! { Self { #name: v } }
8486
} else {
85-
for variant in variants {
86-
let variant_str = variant.to_string();
87-
cases.push(quote! {
88-
#canonical if(src == #variant_str) => #input_type::#variant,
89-
})
87+
quote! { Self(v) }
88+
};
89+
90+
quote! {
91+
#[automatically_derived]
92+
impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause {
93+
type Err = <#inner_ty as derive_more::core::str::FromStr>::Err;
94+
95+
#[inline]
96+
fn from_str(s: &str) -> derive_more::core::result::Result<Self, Self::Err> {
97+
derive_more::core::str::FromStr::from_str(s).map(|v| #constructor)
98+
}
9099
}
91-
}
100+
}.to_tokens(tokens);
92101
}
102+
}
93103

94-
let trait_path = state.trait_path;
104+
/// Expansion of a macro for generating a flat [`FromStr`] implementation of an enum.
105+
struct EnumFlatExpansion<'i> {
106+
/// [`syn::Ident`] and [`syn::Generics`] of the enum.
107+
///
108+
/// [`syn::Ident`]: struct@syn::Ident
109+
self_ty: (&'i syn::Ident, &'i syn::Generics),
110+
111+
/// [`syn::Ident`]s of the enum variants.
112+
///
113+
/// [`syn::Ident`]: struct@syn::Ident
114+
variants: Vec<&'i syn::Ident>,
115+
}
95116

96-
quote! {
97-
impl #trait_path for #input_type {
98-
type Err = derive_more::FromStrError;
117+
impl<'i> TryFrom<&'i syn::DeriveInput> for EnumFlatExpansion<'i> {
118+
type Error = syn::Error;
119+
120+
fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> {
121+
let syn::Data::Enum(data) = &input.data else {
122+
return Err(syn::Error::new(
123+
input.span(),
124+
"expected an enum for flat `FromStr` derive",
125+
));
126+
};
127+
128+
let variants = data
129+
.variants
130+
.iter()
131+
.map(|variant| {
132+
if !variant.fields.is_empty() {
133+
return Err(syn::Error::new(
134+
variant.fields.span(),
135+
"only enums with no fields can derive `FromStr`",
136+
));
137+
}
138+
Ok(&variant.ident)
139+
})
140+
.collect::<syn::Result<_>>()?;
99141

100-
#[inline]
101-
fn from_str(src: &str) -> derive_more::core::result::Result<Self, derive_more::FromStrError> {
102-
Ok(match src.to_lowercase().as_str() {
103-
#(#cases)*
104-
_ => return Err(derive_more::FromStrError::new(#input_type_name)),
105-
})
106-
}
107-
}
142+
Ok(Self {
143+
self_ty: (&input.ident, &input.generics),
144+
variants,
145+
})
108146
}
109147
}
110148

111-
fn panic_one_field(trait_name: &str) -> ! {
112-
panic!("Only structs with one field can derive({trait_name})")
149+
impl ToTokens for EnumFlatExpansion<'_> {
150+
/// Expands a flat [`FromStr`] implementations for an enum.
151+
fn to_tokens(&self, tokens: &mut TokenStream) {
152+
let ty = self.self_ty.0;
153+
let (impl_generics, ty_generics, where_clause) =
154+
self.self_ty.1.split_for_impl();
155+
let ty_name = ty.to_string();
156+
157+
let similar_lowercased = self
158+
.variants
159+
.iter()
160+
.map(|v| v.to_string().to_lowercase())
161+
.fold(<HashMap<_, u8>>::new(), |mut counts, v| {
162+
*counts.entry(v).or_default() += 1;
163+
counts
164+
});
165+
166+
let match_arms = self.variants.iter().map(|variant| {
167+
let name = variant.to_string();
168+
let lowercased = name.to_lowercase();
169+
let exact_guard =
170+
(similar_lowercased[&lowercased] > 1).then(|| quote! { if s == #name });
171+
172+
quote! { #lowercased #exact_guard => Self::#variant, }
173+
});
174+
175+
quote! {
176+
#[automatically_derived]
177+
impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause {
178+
type Err = derive_more::FromStrError;
179+
180+
fn from_str(
181+
s: &str,
182+
) -> derive_more::core::result::Result<Self, derive_more::FromStrError> {
183+
derive_more::core::result::Result::Ok(match s.to_lowercase().as_str() {
184+
#( #match_arms )*
185+
_ => return derive_more::core::result::Result::Err(
186+
derive_more::FromStrError::new(#ty_name),
187+
),
188+
})
189+
}
190+
}
191+
}.to_tokens(tokens);
192+
}
113193
}

impl/src/utils.rs

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use quote::{format_ident, quote, ToTokens};
88
use syn::{
99
parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, Data,
1010
DeriveInput, Error, Field, Fields, FieldsNamed, FieldsUnnamed, GenericParam,
11-
Generics, Ident, ImplGenerics, Index, Result, Token, Type, TypeGenerics,
12-
TypeParamBound, Variant, WhereClause,
11+
Generics, Ident, Index, Result, Token, Type, TypeGenerics, TypeParamBound, Variant,
12+
WhereClause,
1313
};
1414

1515
#[cfg(any(
@@ -577,10 +577,7 @@ impl<'input> State<'input> {
577577
trait_path: data.trait_path,
578578
trait_path_with_params: data.trait_path_with_params.clone(),
579579
casted_trait: data.casted_traits[0].clone(),
580-
impl_generics: data.impl_generics.clone(),
581580
ty_generics: data.ty_generics.clone(),
582-
where_clause: data.where_clause,
583-
multi_field_data: data,
584581
}
585582
}
586583

@@ -608,7 +605,7 @@ impl<'input> State<'input> {
608605
.iter()
609606
.map(|field_type| quote! { <#field_type as #trait_path_with_params> })
610607
.collect();
611-
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
608+
let (_, ty_generics, _) = self.generics.split_for_impl();
612609
let input_type = &self.input.ident;
613610
let (variant_name, variant_type) = self.variant.map_or_else(
614611
|| (None, quote! { #input_type }),
@@ -632,9 +629,7 @@ impl<'input> State<'input> {
632629
trait_path,
633630
trait_path_with_params,
634631
casted_traits,
635-
impl_generics,
636632
ty_generics,
637-
where_clause,
638633
state: self,
639634
}
640635
}
@@ -736,10 +731,7 @@ pub struct SingleFieldData<'input, 'state> {
736731
pub trait_path: &'state TokenStream,
737732
pub trait_path_with_params: TokenStream,
738733
pub casted_trait: TokenStream,
739-
pub impl_generics: ImplGenerics<'state>,
740734
pub ty_generics: TypeGenerics<'state>,
741-
pub where_clause: Option<&'state WhereClause>,
742-
multi_field_data: MultiFieldData<'input, 'state>,
743735
}
744736

745737
#[derive(Clone)]
@@ -758,9 +750,7 @@ pub struct MultiFieldData<'input, 'state> {
758750
pub trait_path: &'state TokenStream,
759751
pub trait_path_with_params: TokenStream,
760752
pub casted_traits: Vec<TokenStream>,
761-
pub impl_generics: ImplGenerics<'state>,
762753
pub ty_generics: TypeGenerics<'state>,
763-
pub where_clause: Option<&'state WhereClause>,
764754
pub state: &'state State<'input>,
765755
}
766756

@@ -804,12 +794,6 @@ impl MultiFieldData<'_, '_> {
804794
}
805795
}
806796

807-
impl SingleFieldData<'_, '_> {
808-
pub fn initializer<T: ToTokens>(&self, initializers: &[T]) -> TokenStream {
809-
self.multi_field_data.initializer(initializers)
810-
}
811-
}
812-
813797
fn get_meta_info(
814798
trait_attr: &str,
815799
attrs: &[Attribute],
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#[derive(derive_more::FromStr)]
2+
enum Enum {
3+
Unit,
4+
Tuple(i32),
5+
}
6+
7+
fn main() {}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
error: only enums with no fields can derive `FromStr`
2+
--> tests/compile_fail/from_str/enum_variant_field.rs:4:10
3+
|
4+
4 | Tuple(i32),
5+
| ^^^^^
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#[derive(derive_more::FromStr)]
2+
pub struct Foo {
3+
foo: i32,
4+
bar: i32,
5+
}
6+
7+
fn main() {}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
error: only structs with single field can derive `FromStr`
2+
--> tests/compile_fail/from_str/struct_multi_fields.rs:2:16
3+
|
4+
2 | pub struct Foo {
5+
| ________________^
6+
3 | | foo: i32,
7+
4 | | bar: i32,
8+
5 | | }
9+
| |_^

tests/compile_fail/from_str/union.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#[derive(derive_more::FromStr)]
2+
union IntOrFloat {
3+
i: u32,
4+
}
5+
6+
fn main() {}

0 commit comments

Comments
 (0)