diff --git a/der/src/lib.rs b/der/src/lib.rs index bb1e7f020..88e7d8dcc 100644 --- a/der/src/lib.rs +++ b/der/src/lib.rs @@ -385,7 +385,7 @@ pub use crate::{ pub use crate::{asn1::Any, document::Document}; #[cfg(feature = "derive")] -pub use der_derive::{BitString, Choice, Enumerated, Sequence, ValueOrd}; +pub use der_derive::{BitString, Choice, DecodeValue, EncodeValue, Enumerated, Sequence, ValueOrd}; #[cfg(feature = "flagset")] pub use flagset; diff --git a/der/tests/derive.rs b/der/tests/derive.rs index d7a8fc87c..3d00fe69d 100644 --- a/der/tests/derive.rs +++ b/der/tests/derive.rs @@ -627,6 +627,56 @@ mod sequence { } } +/// Custom derive test cases for the `EncodeValue` macro. +mod encode_value { + use der::{Encode, EncodeValue, FixedTag, Tag}; + use hex_literal::hex; + + #[derive(EncodeValue, Default, Eq, PartialEq, Debug)] + #[asn1(tag_mode = "IMPLICIT")] + pub struct EncodeOnlyCheck<'a> { + #[asn1(type = "OCTET STRING", context_specific = "5")] + pub field: &'a [u8], + } + impl FixedTag for EncodeOnlyCheck<'_> { + const TAG: Tag = Tag::Sequence; + } + + #[test] + fn sequence_encode_only_to_der() { + let obj = EncodeOnlyCheck { + field: &[0x33, 0x44], + }; + + let der_encoded = obj.to_der().unwrap(); + + assert_eq!(der_encoded, hex!("30 04 85 02 33 44")); + } +} + +/// Custom derive test cases for the `DecodeValue` macro. +mod decode_value { + use der::{Decode, DecodeValue, FixedTag, Tag}; + use hex_literal::hex; + + #[derive(DecodeValue, Default, Eq, PartialEq, Debug)] + #[asn1(tag_mode = "IMPLICIT")] + pub struct DecodeOnlyCheck<'a> { + #[asn1(type = "OCTET STRING", context_specific = "5")] + pub field: &'a [u8], + } + impl FixedTag for DecodeOnlyCheck<'_> { + const TAG: Tag = Tag::Sequence; + } + + #[test] + fn sequence_decode_only_from_der() { + let obj = DecodeOnlyCheck::from_der(&hex!("30 04 85 02 33 44")).unwrap(); + + assert_eq!(obj.field, &[0x33, 0x44]); + } +} + /// Custom derive test cases for the `BitString` macro. #[cfg(feature = "std")] mod bitstring { diff --git a/der_derive/src/lib.rs b/der_derive/src/lib.rs index 6775f4a54..d8a5c0f82 100644 --- a/der_derive/src/lib.rs +++ b/der_derive/src/lib.rs @@ -261,7 +261,7 @@ pub fn derive_enumerated(input: TokenStream) -> TokenStream { } } -/// Derive the [`Sequence`][1] trait on a `struct`. +/// Derive the [`DecodeValue`][1], [`EncodeValue`][2], [`Sequence`][3] traits on a `struct`. /// /// This custom derive macro can be used to automatically impl the /// `Sequence` trait for any struct which can be decoded/encoded as an @@ -289,16 +289,42 @@ pub fn derive_enumerated(input: TokenStream) -> TokenStream { /// /// # `#[asn1(type = "...")]` attribute /// -/// See [toplevel documentation for the `der_derive` crate][2] for more +/// See [toplevel documentation for the `der_derive` crate][4] for more /// information about the `#[asn1]` attribute. /// -/// [1]: https://docs.rs/der/latest/der/trait.Sequence.html -/// [2]: https://docs.rs/der_derive/ +/// [1]: https://docs.rs/der/latest/der/trait.DecodeValue.html +/// [2]: https://docs.rs/der/latest/der/trait.EncodeValue.html +/// [3]: https://docs.rs/der/latest/der/trait.Sequence.html +/// [4]: https://docs.rs/der_derive/ #[proc_macro_derive(Sequence, attributes(asn1))] pub fn derive_sequence(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); match DeriveSequence::new(input) { - Ok(t) => t.to_tokens().into(), + Ok(t) => t.to_tokens_all().into(), + Err(e) => e.to_compile_error().into(), + } +} + +/// Derive the [`EncodeValue`][1] trait on a `struct`. +/// +/// [1]: https://docs.rs/der/latest/der/trait.EncodeValue.html +#[proc_macro_derive(EncodeValue, attributes(asn1))] +pub fn derive_sequence_encode(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match DeriveSequence::new(input) { + Ok(t) => t.to_tokens_encode().into(), + Err(e) => e.to_compile_error().into(), + } +} + +/// Derive the [`DecodeValue`][1] trait on a `struct`. +/// +/// [1]: https://docs.rs/der/latest/der/trait.DecodeValue.html +#[proc_macro_derive(DecodeValue, attributes(asn1))] +pub fn derive_sequence_decode(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match DeriveSequence::new(input) { + Ok(t) => t.to_tokens_decode().into(), Err(e) => e.to_compile_error().into(), } } diff --git a/der_derive/src/sequence.rs b/der_derive/src/sequence.rs index c15ee267e..963db4f8f 100644 --- a/der_derive/src/sequence.rs +++ b/der_derive/src/sequence.rs @@ -7,7 +7,7 @@ use crate::{ErrorType, TypeAttrs, default_lifetime}; use field::SequenceField; use proc_macro2::TokenStream; use quote::{ToTokens, quote}; -use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; +use syn::{DeriveInput, GenericParam, Generics, Ident, Lifetime, LifetimeParam}; /// Derive the `Sequence` trait for a struct pub(crate) struct DeriveSequence { @@ -51,13 +51,10 @@ impl DeriveSequence { }) } - /// Lower the derived output into a [`TokenStream`]. - pub fn to_tokens(&self) -> TokenStream { - let ident = &self.ident; + /// Use the first lifetime parameter as lifetime for Decode/Encode lifetime + /// if none found, add one. + fn calc_lifetime(&self) -> (Generics, Lifetime) { let mut generics = self.generics.clone(); - - // Use the first lifetime parameter as lifetime for Decode/Encode lifetime - // if none found, add one. let lifetime = generics .lifetimes() .next() @@ -69,23 +66,39 @@ impl DeriveSequence { .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone()))); lt }); - // We may or may not have inserted a lifetime. + (generics, lifetime) + } + + /// Lower the derived output into a [`TokenStream`] for Sequence trait impl. + pub fn to_tokens_sequence_trait(&self) -> TokenStream { + let ident = &self.ident; + + let (der_generics, lifetime) = self.calc_lifetime(); + let (_, ty_generics, where_clause) = self.generics.split_for_impl(); - let (impl_generics, _, _) = generics.split_for_impl(); + let (impl_generics, _, _) = der_generics.split_for_impl(); + + quote! { + impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {} + } + } + + /// Lower the derived output into a [`TokenStream`] for DecodeValue trait impl. + pub fn to_tokens_decode(&self) -> TokenStream { + let ident = &self.ident; + + let (der_generics, lifetime) = self.calc_lifetime(); + + let (_, ty_generics, where_clause) = self.generics.split_for_impl(); + let (impl_generics, _, _) = der_generics.split_for_impl(); let mut decode_body = Vec::new(); let mut decode_result = Vec::new(); - let mut encoded_lengths = Vec::new(); - let mut encode_fields = Vec::new(); for field in &self.fields { decode_body.push(field.to_decode_tokens()); decode_result.push(&field.ident); - - let field = field.to_encode_tokens(); - encoded_lengths.push(quote!(#field.encoded_len()?)); - encode_fields.push(quote!(#field.encode(writer)?;)); } let error = self.error.to_token_stream(); @@ -109,6 +122,26 @@ impl DeriveSequence { }) } } + } + } + + /// Lower the derived output into a [`TokenStream`] for EncodeValue trait impl. + pub fn to_tokens_encode(&self) -> TokenStream { + let ident = &self.ident; + + let (_, ty_generics, where_clause) = self.generics.split_for_impl(); + let (impl_generics, _, _) = self.generics.split_for_impl(); + + let mut encoded_lengths = Vec::new(); + let mut encode_fields = Vec::new(); + + for field in &self.fields { + let field = field.to_encode_tokens(); + encoded_lengths.push(quote!(#field.encoded_len()?)); + encode_fields.push(quote!(#field.encode(writer)?;)); + } + + quote! { impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause { fn value_len(&self) -> ::der::Result<::der::Length> { @@ -127,8 +160,22 @@ impl DeriveSequence { Ok(()) } } + } + } - impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {} + /// Lower the derived output into a [`TokenStream`] for trait impls: + /// - EncodeValue + /// - DecodeValue + /// - Sequence + pub fn to_tokens_all(&self) -> TokenStream { + let decode_tokens = self.to_tokens_decode(); + let encode_tokens = self.to_tokens_encode(); + let sequence_trait_tokens = self.to_tokens_sequence_trait(); + + quote! { + #decode_tokens + #encode_tokens + #sequence_trait_tokens } } }