Skip to content

Commit d15a240

Browse files
committed
TagIdTypeChoice extensions
Implement type extensions for TagIdTypeChoice, allowing values of types other than those explicitly specified by the spec. Signed-off-by: setrofim <[email protected]>
1 parent 8ece93f commit d15a240

File tree

1 file changed

+108
-48
lines changed

1 file changed

+108
-48
lines changed

src/comid.rs

Lines changed: 108 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ use serde::{
100100
ser::SerializeMap,
101101
Deserialize, Serialize,
102102
};
103-
use std::{borrow::Cow, fmt::Display, marker::PhantomData};
103+
use std::{fmt::Display, marker::PhantomData, ops::Deref};
104104

105105
/// A tag version number represented as an unsigned integer
106106
pub type TagVersionType = Uint;
@@ -695,9 +695,23 @@ pub enum TagIdTypeChoice<'a> {
695695
Tstr(Tstr<'a>),
696696
/// UUID identifier
697697
Uuid(UuidType),
698+
/// Extensions
699+
Extension(ExtensionValue<'a>),
698700
}
699701

700702
impl TagIdTypeChoice<'_> {
703+
pub fn is_str(&self) -> bool {
704+
matches!(self, Self::Tstr(_))
705+
}
706+
707+
pub fn is_uuid(&self) -> bool {
708+
matches!(self, Self::Uuid(_))
709+
}
710+
711+
pub fn is_extension(&self) -> bool {
712+
matches!(self, Self::Extension(_))
713+
}
714+
701715
/// Returns the tag identifier as a string, if it is a text value
702716
pub fn as_str(&self) -> Option<&str> {
703717
match self {
@@ -737,6 +751,20 @@ impl TagIdTypeChoice<'_> {
737751
_ => None,
738752
}
739753
}
754+
755+
pub fn as_ref_extension(&self) -> Option<&ExtensionValue> {
756+
match self {
757+
Self::Extension(ext) => Some(ext),
758+
_ => None,
759+
}
760+
}
761+
762+
pub fn as_extension(&self) -> Option<ExtensionValue> {
763+
match self {
764+
Self::Extension(ext) => Some(ext.clone()),
765+
_ => None,
766+
}
767+
}
740768
}
741769

742770
impl<'a> From<&'a str> for TagIdTypeChoice<'a> {
@@ -767,59 +795,68 @@ impl<'de> Deserialize<'de> for TagIdTypeChoice<'_> {
767795
where
768796
D: de::Deserializer<'de>,
769797
{
770-
struct TagIdTypeChoiceVisitor<'a> {
771-
marker: PhantomData<&'a str>,
772-
}
773-
774-
impl<'de, 'a> Visitor<'de> for TagIdTypeChoiceVisitor<'a> {
775-
type Value = TagIdTypeChoice<'a>;
776-
777-
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
778-
formatter.write_str("a string or 16 bytes of a UUID")
779-
}
780-
781-
fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
782-
where
783-
E: de::Error,
784-
{
785-
self.visit_string(v.to_string())
786-
}
787-
788-
fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E>
789-
where
790-
E: de::Error,
791-
{
792-
TagIdTypeChoice::try_from(v).map_err(de::Error::custom)
793-
}
798+
let is_human_readable = deserializer.is_human_readable();
794799

795-
fn visit_string<E>(self, v: String) -> std::result::Result<Self::Value, E>
796-
where
797-
E: de::Error,
798-
{
799-
match UuidType::try_from(v.as_str()) {
800+
if is_human_readable {
801+
match serde_json::Value::deserialize(deserializer)? {
802+
serde_json::Value::String(text) => match UuidType::try_from(text.as_str()) {
800803
Ok(uuid) => Ok(TagIdTypeChoice::Uuid(uuid)),
801-
Err(_) => Ok(TagIdTypeChoice::Tstr(Tstr::from(Cow::Owned::<str>(v)))),
804+
Err(_) => Ok(TagIdTypeChoice::Tstr(text.into())),
805+
},
806+
serde_json::Value::Object(map) => {
807+
if map.contains_key("tag") && map.contains_key("value") && map.len() == 2 {
808+
match &map["tag"] {
809+
serde_json::Value::Number(n) => match n.as_u64() {
810+
Some(u) => Ok(TagIdTypeChoice::Extension(ExtensionValue::Tag(
811+
u,
812+
Box::new(
813+
ExtensionValue::try_from(map["value"].clone())
814+
.map_err(de::Error::custom)?,
815+
),
816+
))),
817+
None => Err(de::Error::custom(format!(
818+
"a number must be an unsinged integer, got {n:?}"
819+
))),
820+
},
821+
v => Err(de::Error::custom(format!("invalid tag {v:?}"))),
822+
}
823+
} else {
824+
Ok(TagIdTypeChoice::Extension(
825+
ExtensionValue::try_from(serde_json::Value::Object(map))
826+
.map_err(de::Error::custom)?,
827+
))
828+
}
802829
}
830+
other => Ok(TagIdTypeChoice::Extension(
831+
other.try_into().map_err(de::Error::custom)?,
832+
)),
803833
}
804-
805-
fn visit_borrowed_str<E>(self, v: &'de str) -> std::result::Result<Self::Value, E>
806-
where
807-
E: de::Error,
808-
{
809-
self.visit_str(v)
810-
}
811-
812-
fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> std::result::Result<Self::Value, E>
813-
where
814-
E: de::Error,
815-
{
816-
self.visit_bytes(v)
834+
} else {
835+
match ciborium::Value::deserialize(deserializer)? {
836+
ciborium::Value::Text(text) => Ok(TagIdTypeChoice::Tstr(text.into())),
837+
ciborium::Value::Bytes(bytes) => Ok(TagIdTypeChoice::Uuid(
838+
UuidType::try_from(bytes.as_slice()).map_err(de::Error::custom)?,
839+
)),
840+
ciborium::Value::Tag(tag, inner) => {
841+
// Re-serializing the inner Value so that we can deserialize it
842+
// into an appropriate type, once we figure out what that is
843+
// based on the tag.
844+
let mut buf: Vec<u8> = Vec::new();
845+
ciborium::into_writer(&inner, &mut buf).unwrap();
846+
847+
Ok(TagIdTypeChoice::Extension(ExtensionValue::Tag(
848+
tag,
849+
Box::new(
850+
ExtensionValue::try_from(inner.deref().to_owned())
851+
.map_err(de::Error::custom)?,
852+
),
853+
)))
854+
}
855+
other => Ok(TagIdTypeChoice::Extension(
856+
other.try_into().map_err(de::Error::custom)?,
857+
)),
817858
}
818859
}
819-
820-
deserializer.deserialize_any(TagIdTypeChoiceVisitor {
821-
marker: PhantomData,
822-
})
823860
}
824861
}
825862

@@ -2577,4 +2614,27 @@ mod tests {
25772614

25782615
assert_eq!(actual_json, expected_json);
25792616
}
2617+
2618+
#[test]
2619+
fn test_tag_id_type_choice_serde() {
2620+
let test_cases = vec![
2621+
SerdeTestCase {
2622+
value: TagIdTypeChoice::Tstr("foo".into()),
2623+
expected_json: "\"foo\"",
2624+
expected_cbor: vec![
2625+
0x63, // tstr(3)
2626+
0x66, 0x6f, 0x6f, // "foo"
2627+
],
2628+
},
2629+
SerdeTestCase {
2630+
value: TagIdTypeChoice::Extension(true.into()),
2631+
expected_json: "true",
2632+
expected_cbor: vec![0xf5],
2633+
},
2634+
];
2635+
2636+
for tc in test_cases.into_iter() {
2637+
tc.run()
2638+
}
2639+
}
25802640
}

0 commit comments

Comments
 (0)