|
1 |
| -use crate::utils::{DeriveType, HashMap}; |
2 |
| -use crate::utils::{SingleFieldData, State}; |
3 |
| -use proc_macro2::TokenStream; |
4 |
| -use quote::quote; |
5 |
| -use syn::{parse::Result, DeriveInput}; |
| 1 | +//! Implementation of a [`FromStr`] derive macro. |
6 | 2 |
|
7 |
| -/// Provides the hook to expand `#[derive(FromStr)]` into an implementation of `FromStr` |
8 |
| -pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> { |
9 |
| - let state = State::new(input, trait_name, trait_name.to_lowercase())?; |
| 3 | +use std::collections::HashMap; |
| 4 | +#[cfg(doc)] |
| 5 | +use std::str::FromStr; |
10 | 6 |
|
11 |
| - if state.derive_type == DeriveType::Enum { |
12 |
| - Ok(enum_from(input, state, trait_name)) |
13 |
| - } else { |
14 |
| - Ok(struct_from(&state, trait_name)) |
| 7 | +use proc_macro2::TokenStream; |
| 8 | +use quote::{quote, ToTokens}; |
| 9 | +use syn::{parse_quote, spanned::Spanned as _}; |
| 10 | + |
| 11 | +/// Expands a [`FromStr`] derive macro. |
| 12 | +pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStream> { |
| 13 | + match &input.data { |
| 14 | + syn::Data::Struct(_) => { |
| 15 | + Ok(ForwardExpansion::try_from(input)?.into_token_stream()) |
| 16 | + } |
| 17 | + syn::Data::Enum(_) => { |
| 18 | + Ok(EnumFlatExpansion::try_from(input)?.into_token_stream()) |
| 19 | + } |
| 20 | + syn::Data::Union(data) => Err(syn::Error::new( |
| 21 | + data.union_token.span(), |
| 22 | + "`FromStr` cannot be derived for unions", |
| 23 | + )), |
15 | 24 | }
|
16 | 25 | }
|
17 | 26 |
|
18 |
| -pub fn struct_from(state: &State, trait_name: &'static str) -> TokenStream { |
19 |
| - // We cannot set defaults for fields, once we do we can remove this check |
20 |
| - if state.fields.len() != 1 || state.enabled_fields().len() != 1 { |
21 |
| - panic_one_field(trait_name); |
22 |
| - } |
| 27 | +/// Expansion of a macro for generating a forwarding [`FromStr`] implementation of a struct. |
| 28 | +struct ForwardExpansion<'i> { |
| 29 | + /// [`syn::Ident`] and [`syn::Generics`] of the struct. |
| 30 | + /// |
| 31 | + /// [`syn::Ident`]: struct@syn::Ident |
| 32 | + self_ty: (&'i syn::Ident, &'i syn::Generics), |
23 | 33 |
|
24 |
| - let single_field_data = state.assert_single_enabled_field(); |
25 |
| - let SingleFieldData { |
26 |
| - input_type, |
27 |
| - field_type, |
28 |
| - trait_path, |
29 |
| - casted_trait, |
30 |
| - impl_generics, |
31 |
| - ty_generics, |
32 |
| - where_clause, |
33 |
| - .. |
34 |
| - } = single_field_data.clone(); |
35 |
| - |
36 |
| - let initializers = [quote! { #casted_trait::from_str(src)? }]; |
37 |
| - let body = single_field_data.initializer(&initializers); |
38 |
| - let error = quote! { <#field_type as #trait_path>::Err }; |
39 |
| - |
40 |
| - quote! { |
41 |
| - #[automatically_derived] |
42 |
| - impl #impl_generics #trait_path for #input_type #ty_generics #where_clause { |
43 |
| - type Err = #error; |
44 |
| - |
45 |
| - #[inline] |
46 |
| - fn from_str(src: &str) -> derive_more::core::result::Result<Self, #error> { |
47 |
| - derive_more::core::result::Result::Ok(#body) |
48 |
| - } |
49 |
| - } |
50 |
| - } |
| 34 | + /// [`syn::Field`] representing the wrapped type to forward implementation on. |
| 35 | + inner: &'i syn::Field, |
51 | 36 | }
|
52 | 37 |
|
53 |
| -fn enum_from( |
54 |
| - input: &DeriveInput, |
55 |
| - state: State, |
56 |
| - trait_name: &'static str, |
57 |
| -) -> TokenStream { |
58 |
| - let mut variants_caseinsensitive = HashMap::default(); |
59 |
| - for variant_state in state.enabled_variant_data().variant_states { |
60 |
| - let variant = variant_state.variant.unwrap(); |
61 |
| - if !variant.fields.is_empty() { |
62 |
| - panic!("Only enums with no fields can derive({trait_name})") |
| 38 | +impl<'i> TryFrom<&'i syn::DeriveInput> for ForwardExpansion<'i> { |
| 39 | + type Error = syn::Error; |
| 40 | + |
| 41 | + fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> { |
| 42 | + let syn::Data::Struct(data) = &input.data else { |
| 43 | + return Err(syn::Error::new( |
| 44 | + input.span(), |
| 45 | + "expected a struct for forward `FromStr` derive", |
| 46 | + )); |
| 47 | + }; |
| 48 | + |
| 49 | + // TODO: Unite these two conditions via `&&` once MSRV is bumped to 1.88 or above. |
| 50 | + if data.fields.len() != 1 { |
| 51 | + return Err(syn::Error::new( |
| 52 | + data.fields.span(), |
| 53 | + "only structs with single field can derive `FromStr`", |
| 54 | + )); |
63 | 55 | }
|
64 |
| - |
65 |
| - variants_caseinsensitive |
66 |
| - .entry(variant.ident.to_string().to_lowercase()) |
67 |
| - .or_insert_with(Vec::new) |
68 |
| - .push(variant.ident.clone()); |
| 56 | + let Some(inner) = data.fields.iter().next() else { |
| 57 | + return Err(syn::Error::new( |
| 58 | + data.fields.span(), |
| 59 | + "only structs with single field can derive `FromStr`", |
| 60 | + )); |
| 61 | + }; |
| 62 | + |
| 63 | + Ok(Self { |
| 64 | + self_ty: (&input.ident, &input.generics), |
| 65 | + inner, |
| 66 | + }) |
69 | 67 | }
|
| 68 | +} |
70 | 69 |
|
71 |
| - let input_type = &input.ident; |
72 |
| - let input_type_name = input_type.to_string(); |
73 |
| - |
74 |
| - let mut cases = vec![]; |
| 70 | +impl ToTokens for ForwardExpansion<'_> { |
| 71 | + /// Expands a forwarding [`FromStr`] implementations for a struct. |
| 72 | + fn to_tokens(&self, tokens: &mut TokenStream) { |
| 73 | + let inner_ty = &self.inner.ty; |
| 74 | + let ty = self.self_ty.0; |
| 75 | + |
| 76 | + let mut generics = self.self_ty.1.clone(); |
| 77 | + if !generics.params.is_empty() { |
| 78 | + generics.make_where_clause().predicates.push(parse_quote! { |
| 79 | + #inner_ty: derive_more::core::str::FromStr |
| 80 | + }); |
| 81 | + } |
| 82 | + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); |
75 | 83 |
|
76 |
| - // if a case insensitive match is unique match do that |
77 |
| - // otherwise do a case sensitive match |
78 |
| - for (ref canonical, ref variants) in variants_caseinsensitive { |
79 |
| - if variants.len() == 1 { |
80 |
| - let variant = &variants[0]; |
81 |
| - cases.push(quote! { |
82 |
| - #canonical => #input_type::#variant, |
83 |
| - }) |
| 84 | + let constructor = if let Some(name) = &self.inner.ident { |
| 85 | + quote! { Self { #name: v } } |
84 | 86 | } else {
|
85 |
| - for variant in variants { |
86 |
| - let variant_str = variant.to_string(); |
87 |
| - cases.push(quote! { |
88 |
| - #canonical if(src == #variant_str) => #input_type::#variant, |
89 |
| - }) |
| 87 | + quote! { Self(v) } |
| 88 | + }; |
| 89 | + |
| 90 | + quote! { |
| 91 | + #[automatically_derived] |
| 92 | + impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause { |
| 93 | + type Err = <#inner_ty as derive_more::core::str::FromStr>::Err; |
| 94 | + |
| 95 | + #[inline] |
| 96 | + fn from_str(s: &str) -> derive_more::core::result::Result<Self, Self::Err> { |
| 97 | + derive_more::core::str::FromStr::from_str(s).map(|v| #constructor) |
| 98 | + } |
90 | 99 | }
|
91 |
| - } |
| 100 | + }.to_tokens(tokens); |
92 | 101 | }
|
| 102 | +} |
93 | 103 |
|
94 |
| - let trait_path = state.trait_path; |
| 104 | +/// Expansion of a macro for generating a flat [`FromStr`] implementation of an enum. |
| 105 | +struct EnumFlatExpansion<'i> { |
| 106 | + /// [`syn::Ident`] and [`syn::Generics`] of the enum. |
| 107 | + /// |
| 108 | + /// [`syn::Ident`]: struct@syn::Ident |
| 109 | + self_ty: (&'i syn::Ident, &'i syn::Generics), |
| 110 | + |
| 111 | + /// [`syn::Ident`]s of the enum variants. |
| 112 | + /// |
| 113 | + /// [`syn::Ident`]: struct@syn::Ident |
| 114 | + variants: Vec<&'i syn::Ident>, |
| 115 | +} |
95 | 116 |
|
96 |
| - quote! { |
97 |
| - impl #trait_path for #input_type { |
98 |
| - type Err = derive_more::FromStrError; |
| 117 | +impl<'i> TryFrom<&'i syn::DeriveInput> for EnumFlatExpansion<'i> { |
| 118 | + type Error = syn::Error; |
| 119 | + |
| 120 | + fn try_from(input: &'i syn::DeriveInput) -> syn::Result<Self> { |
| 121 | + let syn::Data::Enum(data) = &input.data else { |
| 122 | + return Err(syn::Error::new( |
| 123 | + input.span(), |
| 124 | + "expected an enum for flat `FromStr` derive", |
| 125 | + )); |
| 126 | + }; |
| 127 | + |
| 128 | + let variants = data |
| 129 | + .variants |
| 130 | + .iter() |
| 131 | + .map(|variant| { |
| 132 | + if !variant.fields.is_empty() { |
| 133 | + return Err(syn::Error::new( |
| 134 | + variant.fields.span(), |
| 135 | + "only enums with no fields can derive `FromStr`", |
| 136 | + )); |
| 137 | + } |
| 138 | + Ok(&variant.ident) |
| 139 | + }) |
| 140 | + .collect::<syn::Result<_>>()?; |
99 | 141 |
|
100 |
| - #[inline] |
101 |
| - fn from_str(src: &str) -> derive_more::core::result::Result<Self, derive_more::FromStrError> { |
102 |
| - Ok(match src.to_lowercase().as_str() { |
103 |
| - #(#cases)* |
104 |
| - _ => return Err(derive_more::FromStrError::new(#input_type_name)), |
105 |
| - }) |
106 |
| - } |
107 |
| - } |
| 142 | + Ok(Self { |
| 143 | + self_ty: (&input.ident, &input.generics), |
| 144 | + variants, |
| 145 | + }) |
108 | 146 | }
|
109 | 147 | }
|
110 | 148 |
|
111 |
| -fn panic_one_field(trait_name: &str) -> ! { |
112 |
| - panic!("Only structs with one field can derive({trait_name})") |
| 149 | +impl ToTokens for EnumFlatExpansion<'_> { |
| 150 | + /// Expands a flat [`FromStr`] implementations for an enum. |
| 151 | + fn to_tokens(&self, tokens: &mut TokenStream) { |
| 152 | + let ty = self.self_ty.0; |
| 153 | + let (impl_generics, ty_generics, where_clause) = |
| 154 | + self.self_ty.1.split_for_impl(); |
| 155 | + let ty_name = ty.to_string(); |
| 156 | + |
| 157 | + let similar_lowercased = self |
| 158 | + .variants |
| 159 | + .iter() |
| 160 | + .map(|v| v.to_string().to_lowercase()) |
| 161 | + .fold(<HashMap<_, u8>>::new(), |mut counts, v| { |
| 162 | + *counts.entry(v).or_default() += 1; |
| 163 | + counts |
| 164 | + }); |
| 165 | + |
| 166 | + let match_arms = self.variants.iter().map(|variant| { |
| 167 | + let name = variant.to_string(); |
| 168 | + let lowercased = name.to_lowercase(); |
| 169 | + let exact_guard = |
| 170 | + (similar_lowercased[&lowercased] > 1).then(|| quote! { if s == #name }); |
| 171 | + |
| 172 | + quote! { #lowercased #exact_guard => Self::#variant, } |
| 173 | + }); |
| 174 | + |
| 175 | + quote! { |
| 176 | + #[automatically_derived] |
| 177 | + impl #impl_generics derive_more::core::str::FromStr for #ty #ty_generics #where_clause { |
| 178 | + type Err = derive_more::FromStrError; |
| 179 | + |
| 180 | + fn from_str( |
| 181 | + s: &str, |
| 182 | + ) -> derive_more::core::result::Result<Self, derive_more::FromStrError> { |
| 183 | + derive_more::core::result::Result::Ok(match s.to_lowercase().as_str() { |
| 184 | + #( #match_arms )* |
| 185 | + _ => return derive_more::core::result::Result::Err( |
| 186 | + derive_more::FromStrError::new(#ty_name), |
| 187 | + ), |
| 188 | + }) |
| 189 | + } |
| 190 | + } |
| 191 | + }.to_tokens(tokens); |
| 192 | + } |
113 | 193 | }
|
0 commit comments