Skip to content

Commit 1d09bb0

Browse files
authored
Support named enum variants (#45)
* Support named enum variants * Don't do default impl for named enums * Bug * bug * bug * bug * bug * bug * Fix cargo fmt
1 parent b2ebabc commit 1d09bb0

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

crates/anchor-idl/src/typedef.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,21 @@ pub fn get_type_properties(defs: &[IdlTypeDef], ty: &IdlType) -> FieldListProper
137137
}
138138
}
139139

140+
/// Generates struct fields from a list of [IdlField]s.
141+
pub fn generate_enum_fields(fields: &[IdlField]) -> TokenStream {
142+
let fields_rendered = fields.iter().map(|arg| {
143+
let name = format_ident!("{}", arg.name.to_snake_case());
144+
let type_name = crate::ty_to_rust_type(&arg.ty);
145+
let stream: proc_macro2::TokenStream = type_name.parse().unwrap();
146+
quote! {
147+
#name: #stream
148+
}
149+
});
150+
quote! {
151+
#(#fields_rendered),*
152+
}
153+
}
154+
140155
/// Generates a struct.
141156
pub fn generate_struct(
142157
defs: &[IdlTypeDef],
@@ -199,7 +214,24 @@ pub fn generate_enum(
199214
enum_name: &Ident,
200215
variants: &[IdlEnumVariant],
201216
) -> TokenStream {
202-
let variant_idents = variants.iter().map(|v| format_ident!("{}", v.name));
217+
let variant_idents = variants.iter().map(|v| {
218+
let name = format_ident!("{}", v.name);
219+
match &v.fields {
220+
Some(EnumFields::Named(idl_fields)) => {
221+
let fields = generate_enum_fields(idl_fields);
222+
quote! {
223+
#name {
224+
#fields
225+
}
226+
}
227+
}
228+
_ => {
229+
quote! {
230+
#name
231+
}
232+
}
233+
}
234+
});
203235
let props = get_variant_list_properties(defs, variants);
204236

205237
let derive_copy = if props.can_copy {
@@ -210,7 +242,24 @@ pub fn generate_enum(
210242
quote! {}
211243
};
212244

213-
let default_variant = format_ident!("{}", variants.first().unwrap().name);
245+
let default_impl = match variants.first() {
246+
Some(IdlEnumVariant {
247+
fields: Some(EnumFields::Named(fields)),
248+
..
249+
}) if fields.len() > 0 => {
250+
quote! {}
251+
}
252+
_ => {
253+
let default_variant = format_ident!("{}", variants.first().unwrap().name);
254+
quote! {
255+
impl Default for #enum_name {
256+
fn default() -> Self {
257+
Self::#default_variant
258+
}
259+
}
260+
}
261+
}
262+
};
214263

215264
quote! {
216265
#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)]
@@ -219,11 +268,7 @@ pub fn generate_enum(
219268
#(#variant_idents),*
220269
}
221270

222-
impl Default for #enum_name {
223-
fn default() -> Self {
224-
Self::#default_variant
225-
}
226-
}
271+
#default_impl
227272
}
228273
}
229274

0 commit comments

Comments
 (0)