diff --git a/Cargo.lock b/Cargo.lock index e14f6ac6..3197a703 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,6 +104,15 @@ dependencies = [ "pinocchio-pubkey", ] +[[package]] +name = "pinocchio_err" +version = "0.1.0" +dependencies = [ + "pinocchio", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.89" diff --git a/Cargo.toml b/Cargo.toml index 06e87351..46bc2145 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ members = [ "programs/memo", "programs/system", "programs/token", - "programs/token-2022", + "programs/token-2022", "sdk/err", "sdk/log/crate", "sdk/log/macro", "sdk/pinocchio", diff --git a/sdk/err/Cargo.toml b/sdk/err/Cargo.toml new file mode 100644 index 00000000..bd1ffd13 --- /dev/null +++ b/sdk/err/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "pinocchio_err" +version = "0.1.0" +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[lib] +proc_macro = true + +[dependencies] +pinocchio.workspace = true +quote.workspace = true +syn.workspace = true diff --git a/sdk/err/src/lib.rs b/sdk/err/src/lib.rs new file mode 100644 index 00000000..17b07f20 --- /dev/null +++ b/sdk/err/src/lib.rs @@ -0,0 +1,62 @@ +use pinocchio::program_error::ProgramError; +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Data, DeriveInput}; + +/// Defines a custom error macro to create custom errors in pinocchio framework +#[proc_macro_derive(ErrorCode, attributes(msg))] +pub fn error_code(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let name = &input.ident; + let variants = if let Data::Enum(data_enum) = &input.data { + &data_enum.variants + } else { + return syn::Error::new_spanned(name, "ErrorCode can only be derived for enums") + .to_compile_error() + .into(); + }; + + let match_arms = variants.iter().map(|variant| { + let variant_name = &variant.ident; + let msg = variant + .attrs + .iter() + .find(|attr| attr.path.is_ident("msg")) + .and_then(|attr| attr.parse_meta().ok()) + .and_then(|meta| { + if let syn::Meta::NameValue(nv) = meta { + if let syn::Lit::Str(lit) = nv.lit { + return Some(lit); + } + } + None + }) + .unwrap_or_else(|| { + panic!( + "Variant `{}` must have a #[msg(\"...\")] attribute", + variant_name + ) + }); + quote! { + Self::#variant_name => #msg, + } + }); + + let expanded = quote! { + impl #name { + pub fn message(&self) -> &str { + match self { + #( #match_arms )* + } + } + } + + impl From<#name> for ProgramError { + fn from(e: #name) -> Self { + ProgramError::Custom(e as u32) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/sdk/pinocchio/src/program_error.rs b/sdk/pinocchio/src/program_error.rs index d086f04e..0cb68d5a 100644 --- a/sdk/pinocchio/src/program_error.rs +++ b/sdk/pinocchio/src/program_error.rs @@ -284,6 +284,38 @@ impl ToStr for ProgramError { } } +/// Require function to check for invariants and error out using custom errors defined in the pinocchio-err crate +/// Logs the message defined in the msg attribute of the custom error's variants before erroring out +/// +/// *Arguments* : +/// +/// `invariant`: The invariant that shouldn't be false +/// +/// `error`: A variant of the custom error defined in your program, which derives the ErrorCode trait +/// +/// **`Note`**: **This function cannot be used with the ProgramError enum, as it's only optimized for custom errors for now** +#[macro_export] +macro_rules! require { + ($invariant:expr, $error:expr $(,)?) => { + if !$invariant { + // If the error type has a `message()` function, call it. + // Otherwise, do nothing. + #[allow(unused_variables)] + { + if false { + // just to scope-check type + } else { + #[allow(unused_unsafe)] + unsafe { + $crate::msg!($error.message()); + } + } + } + return Err($error.into()); + } + }; +} + #[cfg(feature = "std")] impl core::fmt::Display for ProgramError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {