Skip to content

Commit ee71a8d

Browse files
committed
idl: Support NonZero<num> in IDL
1 parent 1ebbe58 commit ee71a8d

File tree

10 files changed

+151
-1
lines changed

10 files changed

+151
-1
lines changed

cli/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2899,6 +2899,7 @@ fn deserialize_idl_type_to_json(
28992899
deserialize_idl_type_to_json(ty, data, parent_idl)?
29002900
}
29012901
}
2902+
IdlType::NonZero(ty) => deserialize_idl_type_to_json(ty, data, parent_idl)?,
29022903
IdlType::Vec(ty) => {
29032904
let size: usize = <u32 as AnchorDeserialize>::deserialize(data)?
29042905
.try_into()

idl/spec/src/lib.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ pub enum IdlType {
306306
#[serde(default, skip_serializing_if = "is_default")]
307307
generics: Vec<IdlGenericArg>,
308308
},
309+
NonZero(Box<IdlType>),
309310
Generic(String),
310311
}
311312

@@ -354,6 +355,20 @@ impl FromStr for IdlType {
354355
)?;
355356
return Ok(IdlType::Vec(Box::new(inner_ty)));
356357
}
358+
// NonZero<u8>
359+
if let Some(inner) = s.strip_prefix("NonZero<") {
360+
let inner_ty = Self::from_str(
361+
inner
362+
.strip_suffix('>')
363+
.ok_or_else(|| anyhow!("Invalid NonZero"))?,
364+
)?;
365+
return Ok(IdlType::NonZero(Box::new(inner_ty)));
366+
}
367+
// NonZeroU8
368+
if let Some(inner) = s.strip_prefix("NonZero") {
369+
let inner_ty = Self::from_str(&inner.to_lowercase())?;
370+
return Ok(IdlType::NonZero(Box::new(inner_ty)));
371+
}
357372

358373
if s.starts_with('[') {
359374
fn array_from_str(inner: &str) -> IdlType {
@@ -499,4 +514,16 @@ mod tests {
499514
}
500515
)
501516
}
517+
518+
#[test]
519+
fn nonzero() {
520+
assert_eq!(
521+
IdlType::from_str("NonZero<u64>").unwrap(),
522+
IdlType::NonZero(Box::new(IdlType::U64)),
523+
);
524+
assert_eq!(
525+
IdlType::from_str("NonZeroU8").unwrap(),
526+
IdlType::NonZero(Box::new(IdlType::U8)),
527+
);
528+
}
502529
}

idl/src/convert.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ mod legacy {
244244
Array(Box<IdlType>, usize),
245245
GenericLenArray(Box<IdlType>, String),
246246
Generic(String),
247+
NonZero(Box<IdlType>),
247248
DefinedWithTypeArgs {
248249
name: String,
249250
args: Vec<IdlDefinedTypeArg>,
@@ -491,6 +492,7 @@ mod legacy {
491492
IdlType::GenericLenArray(ty, generic) => {
492493
t::IdlType::Array(ty.into(), t::IdlArrayLen::Generic(generic))
493494
}
495+
IdlType::NonZero(ty) => t::IdlType::NonZero(ty.into()),
494496
_ => serde_json::to_value(value)
495497
.and_then(serde_json::from_value)
496498
.unwrap(),

lang/attribute/program/src/declare_program/common.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ pub fn convert_idl_type_to_str(ty: &IdlType) -> String {
8989
.map(|generics| format!("{name}<{generics}>"))
9090
.unwrap_or(name.into()),
9191
IdlType::Generic(ty) => ty.into(),
92+
IdlType::NonZero(ty) => format!("NonZero<{}>", convert_idl_type_to_str(ty)),
9293
_ => unimplemented!("{ty:?}"),
9394
}
9495
}
@@ -309,6 +310,7 @@ fn can_derive_copy_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
309310
.map(|ty_def| can_derive_copy(ty_def, ty_defs))
310311
.expect("Type def must exist"),
311312
IdlType::Bytes | IdlType::String | IdlType::Vec(_) | IdlType::Generic(_) => false,
313+
IdlType::NonZero(inner) => can_derive_copy_ty(inner, ty_defs),
312314
_ => true,
313315
}
314316
}
@@ -333,6 +335,7 @@ fn can_derive_default_ty(ty: &IdlType, ty_defs: &[IdlTypeDef]) -> bool {
333335
.map(|ty_def| can_derive_default(ty_def, ty_defs))
334336
.expect("Type def must exist"),
335337
IdlType::Generic(_) => false,
338+
IdlType::NonZero(_) => false,
336339
_ => true,
337340
}
338341
}

lang/syn/src/idl/defined.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ pub fn gen_idl_type(
380380
syn::Type::Path(path) if the_only_segment_is(path, "bool") => {
381381
Ok((quote! { #idl::IdlType::Bool }, vec![]))
382382
}
383+
// Integer types
383384
syn::Type::Path(path) if the_only_segment_is(path, "u8") => {
384385
Ok((quote! { #idl::IdlType::U8 }, vec![]))
385386
}
@@ -416,6 +417,56 @@ pub fn gen_idl_type(
416417
syn::Type::Path(path) if the_only_segment_is(path, "i128") => {
417418
Ok((quote! { #idl::IdlType::I128 }, vec![]))
418419
}
420+
// Non-zero integer types
421+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroU8") => Ok((
422+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::U8)) },
423+
vec![],
424+
)),
425+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroI8") => Ok((
426+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::I8)) },
427+
vec![],
428+
)),
429+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroU16") => Ok((
430+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::U16)) },
431+
vec![],
432+
)),
433+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroI16") => Ok((
434+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::I16)) },
435+
vec![],
436+
)),
437+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroU32") => Ok((
438+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::U32)) },
439+
vec![],
440+
)),
441+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroI32") => Ok((
442+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::I32)) },
443+
vec![],
444+
)),
445+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroU64") => Ok((
446+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::U64)) },
447+
vec![],
448+
)),
449+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroI64") => Ok((
450+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::I64)) },
451+
vec![],
452+
)),
453+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroU128") => Ok((
454+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::U128)) },
455+
vec![],
456+
)),
457+
syn::Type::Path(path) if the_only_segment_is(path, "NonZeroI128") => Ok((
458+
quote! { #idl::IdlType::NonZero(Box::new(#idl::IdlType::I128)) },
459+
vec![],
460+
)),
461+
syn::Type::Path(path) if the_only_segment_is(path, "NonZero") => {
462+
let segment = get_first_segment(path);
463+
let arg = get_angle_bracketed_type_args(segment)
464+
.into_iter()
465+
.next()
466+
.unwrap();
467+
let (inner, defined) = gen_idl_type(arg, generic_params)?;
468+
Ok((quote! { #idl::IdlType::NonZero(Box::new(#inner)) }, defined))
469+
}
419470
syn::Type::Path(path)
420471
if the_only_segment_is(path, "String") || the_only_segment_is(path, "str") =>
421472
{

tests/idl/idls/new.json

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,18 @@
299299
"name": "FooEnum"
300300
}
301301
}
302+
},
303+
{
304+
"name": "non_zero_u8",
305+
"type": {
306+
"nonzero": "u8"
307+
}
308+
},
309+
{
310+
"name": "non_zero_u64",
311+
"type": {
312+
"nonzero": "u64"
313+
}
302314
}
303315
]
304316
},
@@ -811,6 +823,18 @@
811823
"name": "FooEnum"
812824
}
813825
}
826+
},
827+
{
828+
"name": "non_zero_u8",
829+
"type": {
830+
"nonzero": "u8"
831+
}
832+
},
833+
{
834+
"name": "non_zero_u64",
835+
"type": {
836+
"nonzero": "u64"
837+
}
814838
}
815839
]
816840
}
@@ -898,4 +922,4 @@
898922
"value": "6"
899923
}
900924
]
901-
}
925+
}

tests/idl/programs/idl/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use anchor_lang::prelude::*;
22
use anchor_spl::{token, token_interface};
3+
use std::num::{NonZero, NonZeroU64};
34

45
declare_id!("id11111111111111111111111111111111111111111");
56

@@ -56,6 +57,8 @@ pub mod idl {
5657
enum_field_2: FooEnum,
5758
enum_field_3: FooEnum,
5859
enum_field_4: FooEnum,
60+
non_zero_u8: NonZero<u8>,
61+
non_zero_u64: NonZeroU64,
5962
) -> Result<()> {
6063
ctx.accounts.state.set_inner(State {
6164
bool_field,
@@ -84,6 +87,8 @@ pub mod idl {
8487
enum_field_2,
8588
enum_field_3,
8689
enum_field_4,
90+
non_zero_u8,
91+
non_zero_u64,
8792
});
8893

8994
Ok(())
@@ -200,6 +205,8 @@ pub struct State {
200205
enum_field_2: FooEnum,
201206
enum_field_3: FooEnum,
202207
enum_field_4: FooEnum,
208+
non_zero_u8: NonZero<u8>,
209+
non_zero_u64: NonZeroU64,
203210
}
204211

205212
impl Default for State {
@@ -235,6 +242,8 @@ impl Default for State {
235242
},
236243
enum_field_3: FooEnum::Struct(BarStruct::default()),
237244
enum_field_4: FooEnum::NoFields,
245+
non_zero_u8: NonZero::new(1).unwrap(),
246+
non_zero_u64: NonZeroU64::new(1).unwrap(),
238247
}
239248
}
240249
}

ts/packages/anchor/src/coder/borsh/idl.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ export class IdlCoder {
7676
return borsh.publicKey(fieldName);
7777
}
7878
default: {
79+
if ("nonzero" in field.type) {
80+
return borsh.nonzero(
81+
IdlCoder.fieldLayout(
82+
{ type: field.type.nonzero },
83+
types,
84+
genericArgs,
85+
),
86+
fieldName
87+
);
88+
}
7989
if ("option" in field.type) {
8090
return borsh.option(
8191
IdlCoder.fieldLayout(

ts/packages/anchor/src/idl.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ export type IdlType =
231231
| "pubkey"
232232
| IdlTypeOption
233233
| IdlTypeCOption
234+
| IdlTypeNonZero
234235
| IdlTypeVec
235236
| IdlTypeArray
236237
| IdlTypeDefined
@@ -244,6 +245,10 @@ export type IdlTypeCOption = {
244245
coption: IdlType;
245246
};
246247

248+
export type IdlTypeNonZero = {
249+
nonzero: IdlType;
250+
};
251+
247252
export type IdlTypeVec = {
248253
vec: IdlType;
249254
};

ts/packages/borsh/src/index.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,24 @@ export function option<T>(
175175
return new OptionLayout<T>(layout, property);
176176
}
177177

178+
function decodeNonZero<T>(value: T): T {
179+
if (value === 0) {
180+
throw new Error("Invalid nonzero: " + value);
181+
}
182+
return value;
183+
}
184+
185+
function encodeNonZero<T>(value: T): T {
186+
return value;
187+
}
188+
189+
export function nonzero<T>(
190+
layout: Layout<T>,
191+
property?: string
192+
): Layout<T> {
193+
return new WrappedLayout(layout, decodeNonZero, encodeNonZero, property);
194+
}
195+
178196
export function bool(property?: string): Layout<boolean> {
179197
return new WrappedLayout(u8(), decodeBool, encodeBool, property);
180198
}

0 commit comments

Comments
 (0)