Skip to content

Rework FromStr derive #468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 8, 2025
262 changes: 171 additions & 91 deletions impl/src/from_str.rs
Original file line number Diff line number Diff line change
@@ -1,113 +1,193 @@
use crate::utils::{DeriveType, HashMap};
use crate::utils::{SingleFieldData, State};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse::Result, DeriveInput};
//! Implementation of a [`FromStr`] derive macro.

/// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr`
pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
let state = State::new(input, trait_name, trait_name.to_lowercase())?;
use std::collections::HashMap;
#[cfg(doc)]
use std::str::FromStr;

if state.derive_type == DeriveType::Enum {
Ok(enum_from(input, state, trait_name))
} else {
Ok(struct_from(&state, trait_name))
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_quote, spanned::Spanned as _};

/// Expands a [`FromStr`] derive macro.
pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStream> {
match &input.data {
syn::Data::Struct(_) => {
Ok(ForwardExpansion::try_from(input)?.into_token_stream())
}
syn::Data::Enum(_) => {
Ok(EnumFlatExpansion::try_from(input)?.into_token_stream())
}
syn::Data::Union(data) => Err(syn::Error::new(
data.union_token.span(),
"`FromStr` cannot be derived for unions",
)),
}
}

pub fn struct_from(state: &State, trait_name: &'static str) -> TokenStream {
// We cannot set defaults for fields, once we do we can remove this check
if state.fields.len() != 1 || state.enabled_fields().len() != 1 {
panic_one_field(trait_name);
}
/// Expansion of a macro for generating a forwarding [`FromStr`] implementation of a struct.
struct ForwardExpansion<'i> {
/// [`syn::Ident`] and [`syn::Generics`] of the struct.
///
/// [`syn::Ident`]: struct@syn::Ident
self_ty: (&'i syn::Ident, &'i syn::Generics),

let single_field_data = state.assert_single_enabled_field();
let SingleFieldData {
input_type,
field_type,
trait_path,
casted_trait,
impl_generics,
ty_generics,
where_clause,
..
} = single_field_data.clone();

let initializers = [quote! { #casted_trait::from_str(src)? }];
let body = single_field_data.initializer(&initializers);
let error = quote! { <#field_type as #trait_path>::Err };

quote! {
#[automatically_derived]
impl #impl_generics #trait_path for #input_type #ty_generics #where_clause {
type Err = #error;

#[inline]
fn from_str(src: &str) -> derive_more::core::result::Result<Self, #error> {
derive_more::core::result::Result::Ok(#body)
}
}
}
/// [`syn::Field`] representing the wrapped type to forward implementation on.
inner: &'i syn::Field,
}

fn enum_from(
input: &DeriveInput,
state: State,
trait_name: &'static str,
) -> TokenStream {
let mut variants_caseinsensitive = HashMap::default();
for variant_state in state.enabled_variant_data().variant_states {
let variant = variant_state.variant.unwrap();
if !variant.fields.is_empty() {
panic!("Only enums with no fields can derive({trait_name})")
impl<'i> TryFrom<&'i syn::DeriveInput> for ForwardExpansion<'i> {
type Error = syn::Error;

fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> {
let syn::Data::Struct(data) = &input.data else {
return Err(syn::Error::new(
input.span(),
"expected a struct for forward `FromStr` derive",
));
};

// TODO: Unite these two conditions via `&&` once MSRV is bumped to 1.88 or above.
if data.fields.len() != 1 {
return Err(syn::Error::new(
data.fields.span(),
"only structs with single field can derive `FromStr`",
));
}

variants_caseinsensitive
.entry(variant.ident.to_string().to_lowercase())
.or_insert_with(Vec::new)
.push(variant.ident.clone());
let Some(inner) = data.fields.iter().next() else {
return Err(syn::Error::new(
data.fields.span(),
"only structs with single field can derive `FromStr`",
));
};

Ok(Self {
self_ty: (&input.ident, &input.generics),
inner,
})
}
}

let input_type = &input.ident;
let input_type_name = input_type.to_string();

let mut cases = vec![];
impl ToTokens for ForwardExpansion<'_> {
/// Expands a forwarding [`FromStr`] implementations for a struct.
fn to_tokens(&self, tokens: &mut TokenStream) {
let inner_ty = &self.inner.ty;
let ty = self.self_ty.0;

let mut generics = self.self_ty.1.clone();
if !generics.params.is_empty() {
generics.make_where_clause().predicates.push(parse_quote! {
#inner_ty: derive_more::core::str::FromStr
});
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

// if a case insensitive match is unique match do that
// otherwise do a case sensitive match
for (ref canonical, ref variants) in variants_caseinsensitive {
if variants.len() == 1 {
let variant = &variants[0];
cases.push(quote! {
#canonical => #input_type::#variant,
})
let constructor = if let Some(name) = &self.inner.ident {
quote! { Self { #name: v } }
} else {
for variant in variants {
let variant_str = variant.to_string();
cases.push(quote! {
#canonical if(src == #variant_str) => #input_type::#variant,
})
quote! { Self(v) }
};

quote! {
#[automatically_derived]
impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause {
type Err = <#inner_ty as derive_more::core::str::FromStr>::Err;

#[inline]
fn from_str(s: &str) -> derive_more::core::result::Result<Self, Self::Err> {
derive_more::core::str::FromStr::from_str(s).map(|v| #constructor)
}
}
}
}.to_tokens(tokens);
}
}

let trait_path = state.trait_path;
/// Expansion of a macro for generating a flat [`FromStr`] implementation of an enum.
struct EnumFlatExpansion<'i> {
/// [`syn::Ident`] and [`syn::Generics`] of the enum.
///
/// [`syn::Ident`]: struct@syn::Ident
self_ty: (&'i syn::Ident, &'i syn::Generics),

/// [`syn::Ident`]s of the enum variants.
///
/// [`syn::Ident`]: struct@syn::Ident
variants: Vec<&'i syn::Ident>,
}

quote! {
impl #trait_path for #input_type {
type Err = derive_more::FromStrError;
impl<'i> TryFrom<&'i syn::DeriveInput> for EnumFlatExpansion<'i> {
type Error = syn::Error;

fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> {
let syn::Data::Enum(data) = &input.data else {
return Err(syn::Error::new(
input.span(),
"expected an enum for flat `FromStr` derive",
));
};

let variants = data
.variants
.iter()
.map(|variant| {
if !variant.fields.is_empty() {
return Err(syn::Error::new(
variant.fields.span(),
"only enums with no fields can derive `FromStr`",
));
}
Ok(&variant.ident)
})
.collect::<syn::Result<_>>()?;

#[inline]
fn from_str(src: &str) -> derive_more::core::result::Result<Self, derive_more::FromStrError> {
Ok(match src.to_lowercase().as_str() {
#(#cases)*
_ => return Err(derive_more::FromStrError::new(#input_type_name)),
})
}
}
Ok(Self {
self_ty: (&input.ident, &input.generics),
variants,
})
}
}

fn panic_one_field(trait_name: &str) -> ! {
panic!("Only structs with one field can derive({trait_name})")
impl ToTokens for EnumFlatExpansion<'_> {
/// Expands a flat [`FromStr`] implementations for an enum.
fn to_tokens(&self, tokens: &mut TokenStream) {
let ty = self.self_ty.0;
let (impl_generics, ty_generics, where_clause) =
self.self_ty.1.split_for_impl();
let ty_name = ty.to_string();

let similar_lowercased = self
.variants
.iter()
.map(|v| v.to_string().to_lowercase())
.fold(<HashMap<_, u8>>::new(), |mut counts, v| {
*counts.entry(v).or_default() += 1;
counts
});

let match_arms = self.variants.iter().map(|variant| {
let name = variant.to_string();
let lowercased = name.to_lowercase();
let exact_guard =
(similar_lowercased[&lowercased] > 1).then(|| quote! { if s == #name });

quote! { #lowercased #exact_guard => Self::#variant, }
});

quote! {
#[automatically_derived]
impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause {
type Err = derive_more::FromStrError;

fn from_str(
s: &str,
) -> derive_more::core::result::Result<Self, derive_more::FromStrError> {
derive_more::core::result::Result::Ok(match s.to_lowercase().as_str() {
#( #match_arms )*
_ => return derive_more::core::result::Result::Err(
derive_more::FromStrError::new(#ty_name),
),
})
}
}
}.to_tokens(tokens);
}
}
22 changes: 3 additions & 19 deletions impl/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use quote::{format_ident, quote, ToTokens};
use syn::{
parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, Data,
DeriveInput, Error, Field, Fields, FieldsNamed, FieldsUnnamed, GenericParam,
Generics, Ident, ImplGenerics, Index, Result, Token, Type, TypeGenerics,
TypeParamBound, Variant, WhereClause,
Generics, Ident, Index, Result, Token, Type, TypeGenerics, TypeParamBound, Variant,
WhereClause,
};

#[cfg(any(
Expand Down Expand Up @@ -577,10 +577,7 @@ impl<'input> State<'input> {
trait_path: data.trait_path,
trait_path_with_params: data.trait_path_with_params.clone(),
casted_trait: data.casted_traits[0].clone(),
impl_generics: data.impl_generics.clone(),
ty_generics: data.ty_generics.clone(),
where_clause: data.where_clause,
multi_field_data: data,
}
}

Expand Down Expand Up @@ -608,7 +605,7 @@ impl<'input> State<'input> {
.iter()
.map(|field_type| quote! { <#field_type as #trait_path_with_params> })
.collect();
let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let (_, ty_generics, _) = self.generics.split_for_impl();
let input_type = &self.input.ident;
let (variant_name, variant_type) = self.variant.map_or_else(
|| (None, quote! { #input_type }),
Expand All @@ -632,9 +629,7 @@ impl<'input> State<'input> {
trait_path,
trait_path_with_params,
casted_traits,
impl_generics,
ty_generics,
where_clause,
state: self,
}
}
Expand Down Expand Up @@ -736,10 +731,7 @@ pub struct SingleFieldData<'input, 'state> {
pub trait_path: &'state TokenStream,
pub trait_path_with_params: TokenStream,
pub casted_trait: TokenStream,
pub impl_generics: ImplGenerics<'state>,
pub ty_generics: TypeGenerics<'state>,
pub where_clause: Option<&'state WhereClause>,
multi_field_data: MultiFieldData<'input, 'state>,
}

#[derive(Clone)]
Expand All @@ -758,9 +750,7 @@ pub struct MultiFieldData<'input, 'state> {
pub trait_path: &'state TokenStream,
pub trait_path_with_params: TokenStream,
pub casted_traits: Vec<TokenStream>,
pub impl_generics: ImplGenerics<'state>,
pub ty_generics: TypeGenerics<'state>,
pub where_clause: Option<&'state WhereClause>,
pub state: &'state State<'input>,
}

Expand Down Expand Up @@ -804,12 +794,6 @@ impl MultiFieldData<'_, '_> {
}
}

impl SingleFieldData<'_, '_> {
pub fn initializer<T: ToTokens>(&self, initializers: &[T]) -> TokenStream {
self.multi_field_data.initializer(initializers)
}
}

fn get_meta_info(
trait_attr: &str,
attrs: &[Attribute],
Expand Down
7 changes: 7 additions & 0 deletions tests/compile_fail/from_str/enum_variant_field.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#[derive(derive_more::FromStr)]
enum Enum {
Unit,
Tuple(i32),
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/compile_fail/from_str/enum_variant_field.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: only enums with no fields can derive `FromStr`
--> tests/compile_fail/from_str/enum_variant_field.rs:4:10
|
4 | Tuple(i32),
| ^^^^^
7 changes: 7 additions & 0 deletions tests/compile_fail/from_str/struct_multi_fields.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#[derive(derive_more::FromStr)]
pub struct Foo {
foo: i32,
bar: i32,
}

fn main() {}
9 changes: 9 additions & 0 deletions tests/compile_fail/from_str/struct_multi_fields.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
error: only structs with single field can derive `FromStr`
--> tests/compile_fail/from_str/struct_multi_fields.rs:2:16
|
2 | pub struct Foo {
| ________________^
3 | | foo: i32,
4 | | bar: i32,
5 | | }
| |_^
6 changes: 6 additions & 0 deletions tests/compile_fail/from_str/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#[derive(derive_more::FromStr)]
union IntOrFloat {
i: u32,
}

fn main() {}
Loading