Skip to content

Commit ca94bf8

Browse files
committed
Custom handling of HashEntry
- Custom Serialization to match CDDL specification. - Custom `CoseAlgorithm` serialization to properly serialize/deserialize as an i64. Signed-off-by: Larry Dewey <[email protected]>
1 parent 0f0c5d3 commit ca94bf8

File tree

1 file changed

+199
-8
lines changed

1 file changed

+199
-8
lines changed

src/core.rs

Lines changed: 199 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use std::{
4444
};
4545

4646
use derive_more::{AsMut, AsRef, Constructor, From, TryFrom};
47-
use serde::{Deserialize, Serialize};
47+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
4848

4949
use crate::{generate_tagged, FixedBytes};
5050

@@ -230,18 +230,81 @@ pub enum TextOrBytesSized<'a, const N: usize> {
230230

231231
/// Represents a hash entry with algorithm ID and hash value
232232
#[repr(C)]
233-
#[derive(
234-
Debug, Serialize, Deserialize, From, Constructor, PartialEq, Eq, PartialOrd, Ord, Clone,
235-
)]
233+
#[derive(Debug, From, Constructor, PartialEq, Eq, PartialOrd, Ord, Clone)]
236234
pub struct HashEntry {
237235
/// Algorithm identifier for the hash
238-
#[serde(rename = "hash-alg-id")]
239236
pub hash_alg_id: CoseAlgorithm,
240237
/// The hash value as bytes
241-
#[serde(rename = "hash-value")]
242238
pub hash_value: Bytes,
243239
}
244240

241+
impl Serialize for HashEntry {
242+
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
243+
where
244+
S: Serializer,
245+
{
246+
use serde::ser::SerializeSeq;
247+
// The total length is 1 (for hash_alg_id) plus the number of bytes in hash_value
248+
let len = 1 + self.hash_value.len();
249+
let mut seq = serializer.serialize_seq(Some(len))?;
250+
251+
// Serialize hash_alg_id first
252+
seq.serialize_element(&self.hash_alg_id)?;
253+
254+
// Serialize each byte in hash_value individually
255+
for byte in &self.hash_value {
256+
seq.serialize_element(byte)?;
257+
}
258+
259+
seq.end()
260+
}
261+
}
262+
263+
impl<'de> Deserialize<'de> for HashEntry {
264+
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
265+
where
266+
D: Deserializer<'de>,
267+
{
268+
use serde::de::{SeqAccess, Visitor};
269+
use std::fmt;
270+
271+
struct HashEntryVisitor;
272+
273+
impl<'de> Visitor<'de> for HashEntryVisitor {
274+
type Value = HashEntry;
275+
276+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
277+
formatter.write_str(
278+
"a sequence with at least one element (hash_alg_id followed by bytes)",
279+
)
280+
}
281+
282+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
283+
where
284+
A: SeqAccess<'de>,
285+
{
286+
// Get the first element (hash_alg_id)
287+
let hash_alg_id = seq
288+
.next_element::<CoseAlgorithm>()?
289+
.ok_or_else(|| serde::de::Error::custom("missing hash_alg_id"))?;
290+
291+
// Collect the remaining elements as bytes
292+
let mut bytes = Vec::new();
293+
while let Some(byte) = seq.next_element::<u8>()? {
294+
bytes.push(byte);
295+
}
296+
297+
Ok(HashEntry {
298+
hash_alg_id,
299+
hash_value: bytes,
300+
})
301+
}
302+
}
303+
304+
deserializer.deserialize_seq(HashEntryVisitor)
305+
}
306+
}
307+
245308
/// Represents a label that can be either text or integer
246309
#[repr(C)]
247310
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, From, TryFrom)]
@@ -495,9 +558,8 @@ pub enum VersionScheme {
495558
/// let alg = CoseAlgorithm::ES256; // ECDSA with SHA-256
496559
/// let hash_alg = CoseAlgorithm::Sha256; // SHA-256 hash function
497560
/// ```
498-
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
561+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
499562
#[repr(i64)]
500-
#[serde(untagged)]
501563
pub enum CoseAlgorithm {
502564
/// Reserved for private use (-65536)
503565
Unassigned0 = -65536,
@@ -660,3 +722,132 @@ pub enum CoseAlgorithm {
660722
/// For generating IVs (Initialization Vectors)
661723
IvGeneration = 34,
662724
}
725+
726+
impl Serialize for CoseAlgorithm {
727+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
728+
where
729+
S: Serializer,
730+
{
731+
serializer.serialize_i64(self.clone() as i64)
732+
}
733+
}
734+
impl<'de> Deserialize<'de> for CoseAlgorithm {
735+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
736+
where
737+
D: Deserializer<'de>,
738+
{
739+
// Deserialize the value as an i64
740+
let value = i64::deserialize(deserializer)?;
741+
742+
// Match the i64 value to the corresponding enum variant
743+
match value {
744+
-65536 => Ok(CoseAlgorithm::Unassigned0),
745+
-65535 => Ok(CoseAlgorithm::RS1),
746+
-65534 => Ok(CoseAlgorithm::A128CTR),
747+
-65533 => Ok(CoseAlgorithm::A192CTR),
748+
-65532 => Ok(CoseAlgorithm::A256CTR),
749+
-65531 => Ok(CoseAlgorithm::A128CBC),
750+
-65530 => Ok(CoseAlgorithm::A192CBC),
751+
-65529 => Ok(CoseAlgorithm::A256CBC),
752+
-65528 => Ok(CoseAlgorithm::Unassigned1),
753+
-260 => Ok(CoseAlgorithm::WalnutDSA),
754+
-259 => Ok(CoseAlgorithm::RS512),
755+
-258 => Ok(CoseAlgorithm::RS384),
756+
-257 => Ok(CoseAlgorithm::RS256),
757+
-256 => Ok(CoseAlgorithm::Unassigned2),
758+
-47 => Ok(CoseAlgorithm::ES256K),
759+
-46 => Ok(CoseAlgorithm::HssLms),
760+
-45 => Ok(CoseAlgorithm::SHAKE256),
761+
-44 => Ok(CoseAlgorithm::Sha512),
762+
-43 => Ok(CoseAlgorithm::Sha384),
763+
-42 => Ok(CoseAlgorithm::RsaesOaepSha512),
764+
-41 => Ok(CoseAlgorithm::RsaesOaepSha256),
765+
8017 => Ok(CoseAlgorithm::RsaesOaepRfc),
766+
-39 => Ok(CoseAlgorithm::PS512),
767+
-38 => Ok(CoseAlgorithm::PS384),
768+
-37 => Ok(CoseAlgorithm::PS256),
769+
-36 => Ok(CoseAlgorithm::ES512),
770+
-35 => Ok(CoseAlgorithm::ES384),
771+
-34 => Ok(CoseAlgorithm::EcdhSsA256kw),
772+
-33 => Ok(CoseAlgorithm::EcdhSsA192kw),
773+
-32 => Ok(CoseAlgorithm::EcdhSsA128kw),
774+
-31 => Ok(CoseAlgorithm::EcdhEsA256kw),
775+
-30 => Ok(CoseAlgorithm::EcdhEsA192kw),
776+
-29 => Ok(CoseAlgorithm::EcdhEsA128kw),
777+
-28 => Ok(CoseAlgorithm::EcdhSsHkdf512),
778+
-27 => Ok(CoseAlgorithm::EcdhSsHkdf256),
779+
-26 => Ok(CoseAlgorithm::EcdhEsHkdf512),
780+
-25 => Ok(CoseAlgorithm::EcdhEsHkdf256),
781+
-24 => Ok(CoseAlgorithm::Unassigned3),
782+
-18 => Ok(CoseAlgorithm::SHAKE128),
783+
-17 => Ok(CoseAlgorithm::Sha512_256),
784+
-16 => Ok(CoseAlgorithm::Sha256),
785+
-15 => Ok(CoseAlgorithm::Sha256_64),
786+
-14 => Ok(CoseAlgorithm::Sha1),
787+
-13 => Ok(CoseAlgorithm::DirectHkdfAes256),
788+
-12 => Ok(CoseAlgorithm::DirectHkdfAes128),
789+
-11 => Ok(CoseAlgorithm::DirectHkdfSha512),
790+
-10 => Ok(CoseAlgorithm::DirectHkdfSha256),
791+
-9 => Ok(CoseAlgorithm::Unassigned4),
792+
-8 => Ok(CoseAlgorithm::EdDSA),
793+
-7 => Ok(CoseAlgorithm::ES256),
794+
-6 => Ok(CoseAlgorithm::Direct),
795+
-5 => Ok(CoseAlgorithm::A256KW),
796+
-4 => Ok(CoseAlgorithm::A192KW),
797+
-3 => Ok(CoseAlgorithm::A128KW),
798+
-2 => Ok(CoseAlgorithm::Unassigned5),
799+
0 => Ok(CoseAlgorithm::Reserved),
800+
1 => Ok(CoseAlgorithm::A128GCM),
801+
2 => Ok(CoseAlgorithm::A192GCM),
802+
3 => Ok(CoseAlgorithm::A256GCM),
803+
4 => Ok(CoseAlgorithm::Hmac256_64),
804+
5 => Ok(CoseAlgorithm::Hmac256_256),
805+
6 => Ok(CoseAlgorithm::Hmac384_384),
806+
7 => Ok(CoseAlgorithm::Hmac512_512),
807+
8 => Ok(CoseAlgorithm::Unassigned6),
808+
10 => Ok(CoseAlgorithm::AesCcm16_64_128),
809+
11 => Ok(CoseAlgorithm::AesCcm16_64_256),
810+
12 => Ok(CoseAlgorithm::AesCcm64_64_128),
811+
13 => Ok(CoseAlgorithm::AesCcm64_64_256),
812+
14 => Ok(CoseAlgorithm::AesMac128_64),
813+
15 => Ok(CoseAlgorithm::AesMac256_64),
814+
16 => Ok(CoseAlgorithm::Unassigned7),
815+
24 => Ok(CoseAlgorithm::ChaCha20Poly1305),
816+
128 => Ok(CoseAlgorithm::AesMac128),
817+
256 => Ok(CoseAlgorithm::AesMac256),
818+
27 => Ok(CoseAlgorithm::Unassigned8),
819+
30 => Ok(CoseAlgorithm::AesCcm16_128_128),
820+
31 => Ok(CoseAlgorithm::AesCcm16_128_256),
821+
32 => Ok(CoseAlgorithm::AesCcm64_128_128),
822+
33 => Ok(CoseAlgorithm::AesCcm64_128_256),
823+
34 => Ok(CoseAlgorithm::IvGeneration),
824+
// If the value doesn't match any variant, return an error
825+
_ => Err(serde::de::Error::invalid_value(
826+
serde::de::Unexpected::Signed(value),
827+
&"a valid COSE algorithm identifier",
828+
)),
829+
}
830+
}
831+
}
832+
833+
#[cfg(test)]
834+
mod tests {
835+
use super::{CoseAlgorithm, HashEntry};
836+
837+
#[test]
838+
fn test_hash_entry_serialize() {
839+
let expected = [134, 1, 1, 2, 3, 4, 5];
840+
let actual: HashEntry = HashEntry {
841+
hash_alg_id: CoseAlgorithm::A128GCM,
842+
hash_value: vec![1, 2, 3, 4, 5],
843+
};
844+
845+
println!("{expected:02X?}");
846+
847+
let mut bytes: Vec<u8> = vec![];
848+
ciborium::into_writer(&actual, &mut bytes).unwrap();
849+
println!("{bytes:02X?}");
850+
851+
assert_eq!(bytes.as_slice(), expected.as_slice());
852+
}
853+
}

0 commit comments

Comments
 (0)