diff --git a/source/postcard-dyn/Cargo.toml b/source/postcard-dyn/Cargo.toml index acbab013..47cd057b 100644 --- a/source/postcard-dyn/Cargo.toml +++ b/source/postcard-dyn/Cargo.toml @@ -19,7 +19,9 @@ documentation = "https://docs.rs/postcard-dyn/" [dependencies] +hashbrown = { version = "0.15.2", default-features = false, features = ["default-hasher"] } serde = { version = "1.0.202", features = ["derive"] } +serde-content = "0.1.0" serde_json = "1.0.117" [dependencies.postcard] @@ -31,3 +33,7 @@ path = "../postcard" version = "0.2" features = ["use-std", "derive"] path = "../postcard-schema" + +[dev-dependencies.serde_json] +version = "1.0" +features = ["preserve_order"] diff --git a/source/postcard-dyn/src/de.rs b/source/postcard-dyn/src/de.rs index 7e532b1f..ff566bd2 100644 --- a/source/postcard-dyn/src/de.rs +++ b/source/postcard-dyn/src/de.rs @@ -1,454 +1,18 @@ -use std::str::from_utf8; - -use postcard_schema::schema::owned::{OwnedData, OwnedDataModelType}; -use serde_json::{Map, Number, Value}; - -use crate::de::varint::de_zig_zag_i16; - -use self::varint::{ - de_zig_zag_i128, de_zig_zag_i32, de_zig_zag_i64, try_take_varint_u128, try_take_varint_u16, - try_take_varint_u32, try_take_varint_u64, try_take_varint_usize, -}; - -#[derive(Debug, PartialEq)] -pub enum Error { - UnexpectedEndOfData, - ShouldSupportButDont, - SchemaMismatch, -} - -trait GetExt { - type Out; - fn right(self) -> Result; -} - -impl GetExt for Option { - type Out = T; - - fn right(self) -> Result { - self.ok_or(Error::SchemaMismatch) - } -} - -pub fn from_slice_dyn(schema: &OwnedDataModelType, data: &[u8]) -> Result { - let (val, _remain) = deserialize(schema, data)?; - Ok(val) -} - -fn deserialize<'a>(ty: &OwnedDataModelType, data: &'a [u8]) -> Result<(Value, &'a [u8]), Error> { - match ty { - OwnedDataModelType::Bool => { - let (one, rest) = data.take_one()?; - let val = match one { - 0 => Value::Bool(false), - 1 => Value::Bool(true), - _ => return Err(Error::SchemaMismatch), - }; - Ok((val, rest)) - } - OwnedDataModelType::I8 => { - let (one, rest) = data.take_one()?; - let val = Value::Number(Number::from(one as i8)); - Ok((val, rest)) - } - OwnedDataModelType::U8 => { - let (one, rest) = data.take_one()?; - let val = Value::Number(Number::from(one)); - Ok((val, rest)) - } - OwnedDataModelType::I16 => { - let (val, rest) = try_take_varint_u16(data)?; - let val = de_zig_zag_i16(val); - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::I32 => { - let (val, rest) = try_take_varint_u32(data)?; - let val = de_zig_zag_i32(val); - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::I64 => { - let (val, rest) = try_take_varint_u64(data)?; - let val = de_zig_zag_i64(val); - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::I128 => { - let (val, rest) = try_take_varint_u128(data)?; - let val = de_zig_zag_i128(val); - let val = i64::try_from(val).map_err(|_| Error::ShouldSupportButDont)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::U16 => { - let (val, rest) = try_take_varint_u16(data)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::U32 => { - let (val, rest) = try_take_varint_u32(data)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::U64 => { - let (val, rest) = try_take_varint_u64(data)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::U128 => { - let (val, rest) = try_take_varint_u128(data)?; - let val = u64::try_from(val).map_err(|_| Error::ShouldSupportButDont)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::Usize => { - let (val, rest) = try_take_varint_usize(data)?; - let val = Value::Number(Number::from(val)); - Ok((val, rest)) - } - OwnedDataModelType::Isize => { - let (val, rest) = try_take_varint_usize(data)?; - - #[cfg(target_pointer_width = "16")] - let valu = de_zig_zag_i16(val as u16); - - #[cfg(target_pointer_width = "32")] - let valu = de_zig_zag_i32(val as u32); - - #[cfg(target_pointer_width = "64")] - let valu = de_zig_zag_i64(val as u64); - - let valu = Value::Number(Number::from(valu)); - Ok((valu, rest)) - } - OwnedDataModelType::F32 => { - let (val, rest) = data.take_n(4)?; - let mut buf = [0u8; 4]; - buf.copy_from_slice(val); - let f = f32::from_le_bytes(buf); - let val = Value::Number(Number::from_f64(f.into()).right()?); - Ok((val, rest)) - } - OwnedDataModelType::F64 => { - let (val, rest) = data.take_n(8)?; - let mut buf = [0u8; 8]; - buf.copy_from_slice(val); - let f = f64::from_le_bytes(buf); - let val = Value::Number(Number::from_f64(f).right()?); - Ok((val, rest)) - } - OwnedDataModelType::Char => todo!(), - OwnedDataModelType::String => { - let (val, rest) = try_take_varint_usize(data)?; - let (bytes, rest) = rest.take_n(val)?; - let s = from_utf8(bytes).map_err(|_| Error::SchemaMismatch)?; - let val = Value::String(s.to_string()); - Ok((val, rest)) - } - OwnedDataModelType::ByteArray => { - let (val, rest) = try_take_varint_usize(data)?; - let (bytes, rest) = rest.take_n(val)?; - let vvec = bytes - .iter() - .map(|b| Value::Number(Number::from(*b))) - .collect::>(); - let val = Value::Array(vvec); - Ok((val, rest)) - } - OwnedDataModelType::Option(inner) => { - let (val, rest) = data.take_one()?; - match val { - 0 => return Ok((Value::Null, rest)), - 1 => {} - _ => return Err(Error::SchemaMismatch), - } - deserialize(inner, rest) - } - OwnedDataModelType::Unit - | OwnedDataModelType::Struct { - name: _, - data: OwnedData::Unit, - } => { - // TODO This is PROBABLY wrong, as Some(()) will be coalesced into the same - // value as None. Fix this when we have our own Value - Ok((Value::Null, data)) - } - OwnedDataModelType::Struct { - name: _, - data: OwnedData::Newtype(ty), - } => deserialize(ty, data), - OwnedDataModelType::Seq(ty) => { - let (val, mut rest) = try_take_varint_usize(data)?; - let mut vec = vec![]; - for _ in 0..val { - let (v, irest) = deserialize(ty, rest)?; - rest = irest; - vec.push(v); - } - Ok((Value::Array(vec), rest)) - } - OwnedDataModelType::Tuple(tys) - | OwnedDataModelType::Struct { - name: _, - data: OwnedData::Tuple(tys), - } => { - match &tys[..] { - [] => { - // TODO: Not sure this is right... - Ok((Value::Null, data)) - } - [ty] => { - // Single item, NOT an array - deserialize(ty, data) - } - multi => { - let mut vec = vec![]; - let mut rest = data; - for ty in multi.iter() { - let (val, irest) = deserialize(ty, rest)?; - rest = irest; - vec.push(val); - } - Ok((Value::Array(vec), rest)) - } - } - } - OwnedDataModelType::Map { key, val } => { - // TODO: impling blind because we can't test this, oops - // - // TODO: There's also a mismatch here because serde_json::Value requires - // keys to be strings, when postcard doesn't. - if **key != OwnedDataModelType::String { - return Err(Error::ShouldSupportButDont); - } - - let (map_len, mut rest) = try_take_varint_usize(data)?; - let mut map = Map::new(); - - for _ in 0..map_len { - let (str_len, irest) = try_take_varint_usize(rest)?; - let (bytes, irest) = irest.take_n(str_len)?; - let s = from_utf8(bytes).map_err(|_| Error::SchemaMismatch)?; - - let (v, irest) = deserialize(val, irest)?; - rest = irest; - - map.insert(s.to_string(), v); - } - - Ok((Value::Object(map), rest)) - } - OwnedDataModelType::Struct { - name: _, - data: OwnedData::Struct(nvs), - } => { - let mut map = Map::new(); - let mut rest = data; - for nv in nvs.iter() { - let (val, irest) = deserialize(&nv.ty, rest)?; - rest = irest; - map.insert(nv.name.to_string(), val); - } - Ok((Value::Object(map), rest)) - } - OwnedDataModelType::Enum { - name: _, - variants: nvars, - } => { - let (variant, rest) = try_take_varint_usize(data)?; - let schema = nvars.get(variant).right()?; - match &schema.data { - OwnedData::Unit => { - // Units become strings - Ok((Value::String(schema.name.to_string()), rest)) - } - OwnedData::Newtype(ty) => { - // everything else becomes an object with one field - let (val, irest) = deserialize(ty, rest)?; - let mut map = Map::new(); - map.insert(schema.name.to_owned().to_string(), val); - Ok((Value::Object(map), irest)) - } - OwnedData::Tuple(vec) => { - // everything else becomes an object with one field - let (val, irest) = deserialize(&OwnedDataModelType::Tuple(vec.clone()), rest)?; - let mut map = Map::new(); - map.insert(schema.name.to_owned().to_string(), val); - Ok((Value::Object(map), irest)) - } - OwnedData::Struct(vec) => { - // everything else becomes an object with one field - let (val, irest) = deserialize( - &OwnedDataModelType::Struct { - name: schema.name.clone(), - data: OwnedData::Struct(vec.clone()), - }, - rest, - )?; - let mut map = Map::new(); - map.insert(schema.name.to_owned().to_string(), val); - Ok((Value::Object(map), irest)) - } - } - } - OwnedDataModelType::Schema => todo!(), - } -} - -mod varint { - // copy and paste from postcard - - use crate::ser::varint::varint_max; - - use super::{Error, TakeExt}; - - /// Returns the maximum value stored in the last encoded byte. - pub const fn max_of_last_byte() -> u8 { - let max_bits = core::mem::size_of::() * 8; - let extra_bits = max_bits % 7; - (1 << extra_bits) - 1 - } - - pub fn de_zig_zag_i16(n: u16) -> i16 { - ((n >> 1) as i16) ^ (-((n & 0b1) as i16)) - } - - pub fn de_zig_zag_i32(n: u32) -> i32 { - ((n >> 1) as i32) ^ (-((n & 0b1) as i32)) - } - - pub fn de_zig_zag_i64(n: u64) -> i64 { - ((n >> 1) as i64) ^ (-((n & 0b1) as i64)) - } - - pub fn de_zig_zag_i128(n: u128) -> i128 { - ((n >> 1) as i128) ^ (-((n & 0b1) as i128)) - } - - #[cfg(target_pointer_width = "16")] - #[inline(always)] - pub fn try_take_varint_usize(data: &[u8]) -> Result<(usize, &[u8]), Error> { - try_take_varint_u16(data).map(|(u, rest)| (u as usize, rest)) - } - - #[cfg(target_pointer_width = "32")] - #[inline(always)] - pub fn try_take_varint_usize(data: &[u8]) -> Result<(usize, &[u8]), Error> { - try_take_varint_u32(data).map(|(u, rest)| (u as usize, rest)) - } - - #[cfg(target_pointer_width = "64")] - #[inline(always)] - pub fn try_take_varint_usize(data: &[u8]) -> Result<(usize, &[u8]), Error> { - try_take_varint_u64(data).map(|(u, rest)| (u as usize, rest)) - } - - #[inline] - pub fn try_take_varint_u16(data: &[u8]) -> Result<(u16, &[u8]), Error> { - let mut rest = data; - let mut out = 0; - for i in 0..varint_max::() { - let (val, later) = rest.take_one()?; - rest = later; - let carry = (val & 0x7F) as u16; - out |= carry << (7 * i); - - if (val & 0x80) == 0 { - if i == varint_max::() - 1 && val > max_of_last_byte::() { - return Err(Error::SchemaMismatch); - } else { - return Ok((out, rest)); - } - } - } - Err(Error::SchemaMismatch) - } - - #[inline] - pub fn try_take_varint_u32(data: &[u8]) -> Result<(u32, &[u8]), Error> { - let mut rest = data; - let mut out = 0; - for i in 0..varint_max::() { - let (val, later) = rest.take_one()?; - rest = later; - let carry = (val & 0x7F) as u32; - out |= carry << (7 * i); - - if (val & 0x80) == 0 { - if i == varint_max::() - 1 && val > max_of_last_byte::() { - return Err(Error::SchemaMismatch); - } else { - return Ok((out, rest)); - } - } - } - Err(Error::SchemaMismatch) - } - - #[inline] - pub fn try_take_varint_u64(data: &[u8]) -> Result<(u64, &[u8]), Error> { - let mut rest = data; - let mut out = 0; - for i in 0..varint_max::() { - let (val, later) = rest.take_one()?; - rest = later; - let carry = (val & 0x7F) as u64; - out |= carry << (7 * i); - - if (val & 0x80) == 0 { - if i == varint_max::() - 1 && val > max_of_last_byte::() { - return Err(Error::SchemaMismatch); - } else { - return Ok((out, rest)); - } - } - } - Err(Error::SchemaMismatch) - } - - #[inline] - pub fn try_take_varint_u128(data: &[u8]) -> Result<(u128, &[u8]), Error> { - let mut rest = data; - let mut out = 0; - for i in 0..varint_max::() { - let (val, later) = rest.take_one()?; - rest = later; - let carry = (val & 0x7F) as u128; - out |= carry << (7 * i); - - if (val & 0x80) == 0 { - if i == varint_max::() - 1 && val > max_of_last_byte::() { - return Err(Error::SchemaMismatch); - } else { - return Ok((out, rest)); - } - } - } - Err(Error::SchemaMismatch) - } -} - -trait TakeExt { - fn take_one(&self) -> Result<(u8, &[u8]), Error>; - fn take_n(&self, n: usize) -> Result<(&[u8], &[u8]), Error>; -} - -impl TakeExt for [u8] { - fn take_one(&self) -> Result<(u8, &[u8]), Error> { - if let Some((first, rest)) = self.split_first() { - Ok((*first, rest)) - } else { - Err(Error::UnexpectedEndOfData) - } - } - - fn take_n(&self, n: usize) -> Result<(&[u8], &[u8]), Error> { - if self.len() < n { - return Err(Error::UnexpectedEndOfData); - } - Ok(self.split_at(n)) - } +use postcard_schema::schema::owned::OwnedDataModelType; +use serde_json::Value; + +use crate::Error; + +pub fn from_slice_dyn( + schema: &OwnedDataModelType, + data: &[u8], +) -> Result> { + // Matches current value type (`serde_json::Value`)'s representation + crate::reserialize::lossy::reserialize_with_structs_and_enums_as_maps( + schema, + &mut postcard::Deserializer::from_bytes(data), + serde_json::value::Serializer, + ) } #[cfg(test)] diff --git a/source/postcard-dyn/src/error.rs b/source/postcard-dyn/src/error.rs new file mode 100644 index 00000000..6c7bf1d0 --- /dev/null +++ b/source/postcard-dyn/src/error.rs @@ -0,0 +1,26 @@ +use core::fmt::{self, Display}; + +/// Errors encountered by `postcard-dyn` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Error { + Deserialize(DeserializeError), + Serialize(SerializeError), +} + +impl Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Deserialize(err) => Display::fmt(err, f), + Self::Serialize(err) => Display::fmt(err, f), + } + } +} + +impl core::error::Error for Error { + fn source(&self) -> Option<&(dyn core::error::Error + 'static)> { + match self { + Self::Deserialize(err) => err.source(), + Self::Serialize(err) => err.source(), + } + } +} diff --git a/source/postcard-dyn/src/lib.rs b/source/postcard-dyn/src/lib.rs index 2316cf9e..cf0fdd80 100644 --- a/source/postcard-dyn/src/lib.rs +++ b/source/postcard-dyn/src/lib.rs @@ -1,6 +1,10 @@ +mod error; + mod de; +pub mod reserialize; mod ser; pub use de::from_slice_dyn; +pub use error::Error; pub use ser::to_stdvec_dyn; pub use serde_json::Value; diff --git a/source/postcard-dyn/src/reserialize.rs b/source/postcard-dyn/src/reserialize.rs new file mode 100644 index 00000000..b20e133d --- /dev/null +++ b/source/postcard-dyn/src/reserialize.rs @@ -0,0 +1,423 @@ +//! Dynamically reserialize [`postcard`]-encoded values from [`Deserializer`]s into [`Serializer`]s. +//! +//! This module implements transformations between postcard-encoded data and other serialized forms +//! based on [dynamic schemas](postcard_schema::schema::owned). For example, this could be used to +//! transform postcard-encoded data to JSON or another human-readable format, or to transform JSON +//! to postcard. +//! +//! # Limitations +//! +//! Several [`Deserializer`] and [`Serializer`] methods require `&'static` parameters, namely: +//! - [`Serializer`] methods for serializing structs and enums require `&'static str`s for the +//! names of structs, fields, enums, and variants. +//! - [`Deserializer`] methods for deserializing structs and enums require the same `&'static str`s +//! as the corresponding serialize methods, and moreover `&'static [&'static str]`s for the names +//! of fields in structs. +//! +//! Since these transformations work with dynamic schemas that contain [`String`]s instead of +//! `&'static str`s, lossless reserialization is possible only with other compromises. +//! +//! In particular, reserialization can be either: +//! - Lossless with implementation compromises: see [`lossless`] +//! - Lossy with regards to structs and enums: see [`lossy`] + +use core::{cell::Cell, fmt, marker::PhantomData, slice}; + +use postcard_schema::schema::owned::{OwnedData, OwnedDataModelType}; +use serde::{de, ser::Error as _, Deserialize, Deserializer, Serialize, Serializer}; + +use crate::Error; + +pub mod lossless; +pub mod lossy; + +mod expecting; +mod strategy; +use strategy::Strategy; + +mod map; +mod option; +mod seq; +mod tuple; + +struct Context<'a, Strategy> { + strategy: &'a Strategy, +} + +struct Reserialize { + f: Cell>, + deserializer_error: Cell>, +} + +trait ReserializeFn { + type DeserializeError: de::Error; + + fn reserialize( + self, + serializer: S, + ) -> Result>; +} + +impl serde::Serialize for Reserialize { + fn serialize(&self, serializer: S) -> Result { + let f = self.f.take().unwrap(); + match f.reserialize(serializer) { + Ok(out) => Ok(out), + Err(Error::Serialize(err)) => Err(err), + Err(Error::Deserialize(err)) => { + let res = Err(S::Error::custom(format_args!("{err}"))); + self.deserializer_error.set(Some(err)); + res + } + } + } +} + +impl Context<'_, Strategy> { + fn reserialize( + &self, + reserialize: F, + f: impl FnOnce(&Reserialize) -> T, + ) -> Result { + let reserialize = Reserialize { + f: Cell::new(Some(reserialize)), + deserializer_error: Cell::new(None), + }; + let res = f(&reserialize); + match reserialize.deserializer_error.take() { + Some(err) => Err(err), + None => Ok(res), + } + } + + fn reserialize_ty<'de, D: Deserializer<'de>, T>( + &self, + schema: &OwnedDataModelType, + deserializer: D, + f: impl FnOnce(&Reserialize>) -> T, + ) -> Result { + self.reserialize( + ReserializeTy { + context: self, + deserializer, + schema, + de: PhantomData, + }, + f, + ) + } +} + +struct ReserializeTy<'a, 'de, D, Strategy> { + context: &'a Context<'a, Strategy>, + deserializer: D, + schema: &'a OwnedDataModelType, + de: PhantomData<&'de ()>, +} + +impl<'de, D, Strategy> ReserializeFn for ReserializeTy<'_, 'de, D, Strategy> +where + D: Deserializer<'de>, + Strategy: strategy::Strategy, +{ + type DeserializeError = D::Error; + + fn reserialize( + self, + serializer: S, + ) -> Result> { + fn deserialize<'de, T, D, SerializerError>( + deserializer: D, + ) -> Result> + where + T: Deserialize<'de>, + D: Deserializer<'de>, + { + T::deserialize(deserializer).map_err(Error::Deserialize) + } + let (context, deserializer, schema) = (self.context, self.deserializer, self.schema); + match schema { + OwnedDataModelType::Schema => OwnedDataModelType::deserialize(deserializer) + .map_err(Error::Deserialize)? + .serialize(serializer), + OwnedDataModelType::Unit => serializer.serialize_unit(), + OwnedDataModelType::Bool => serializer.serialize_bool(deserialize(deserializer)?), + OwnedDataModelType::U8 => serializer.serialize_u8(deserialize(deserializer)?), + OwnedDataModelType::U16 => serializer.serialize_u16(deserialize(deserializer)?), + OwnedDataModelType::U32 => serializer.serialize_u32(deserialize(deserializer)?), + OwnedDataModelType::U64 => serializer.serialize_u64(deserialize(deserializer)?), + OwnedDataModelType::U128 => serializer.serialize_u128(deserialize(deserializer)?), + OwnedDataModelType::I8 => serializer.serialize_i8(deserialize(deserializer)?), + OwnedDataModelType::I16 => serializer.serialize_i16(deserialize(deserializer)?), + OwnedDataModelType::I32 => serializer.serialize_i32(deserialize(deserializer)?), + OwnedDataModelType::I64 => serializer.serialize_i64(deserialize(deserializer)?), + OwnedDataModelType::I128 => serializer.serialize_i128(deserialize(deserializer)?), + OwnedDataModelType::Usize => { + deserialize::(deserializer)?.serialize(serializer) + } + OwnedDataModelType::Isize => { + deserialize::(deserializer)?.serialize(serializer) + } + OwnedDataModelType::F32 => serializer.serialize_f32(deserialize(deserializer)?), + OwnedDataModelType::F64 => serializer.serialize_f64(deserialize(deserializer)?), + OwnedDataModelType::Char => serializer.serialize_char(deserialize(deserializer)?), + OwnedDataModelType::String => serializer.serialize_str(deserialize(deserializer)?), + OwnedDataModelType::ByteArray => serializer.serialize_bytes(deserialize(deserializer)?), + OwnedDataModelType::Option(inner) => deserializer + .deserialize_option(option::Visitor { + context, + serializer, + schema: inner, + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Map { key, val } => deserializer + .deserialize_map(map::Visitor { + context, + serializer, + key, + val, + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Seq(element) => deserializer + .deserialize_seq(seq::Visitor { + context, + serializer, + schemas: slice::from_ref(element), + }) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Tuple(elements) => deserializer + .deserialize_tuple( + elements.len(), + tuple::Visitor { + context, + serializer, + fields: elements, + reserializer: expecting::Tuple, + }, + ) + .map_err(Error::Deserialize)?, + OwnedDataModelType::Struct { name, data } => match data { + OwnedData::Unit => { + Strategy::reserialize_unit_struct(context, deserializer, serializer, name) + .map_err(Error::Deserialize)? + } + OwnedData::Newtype(inner) => Strategy::reserialize_newtype_struct( + context, + deserializer, + serializer, + expecting::Struct { + name, + data: expecting::data::Newtype { schema: inner }, + }, + ) + .map_err(Error::Deserialize)?, + OwnedData::Tuple(fields) => Strategy::reserialize_tuple_struct( + context, + deserializer, + serializer, + expecting::Struct { + name, + data: expecting::data::Tuple { elements: fields }, + }, + ) + .map_err(Error::Deserialize)?, + OwnedData::Struct(fields) => Strategy::reserialize_struct( + context, + deserializer, + serializer, + expecting::Struct { + name, + data: expecting::data::Struct { fields }, + }, + ) + .map_err(Error::Deserialize)?, + }, + OwnedDataModelType::Enum { name, variants } => Strategy::reserialize_enum( + context, + deserializer, + serializer, + expecting::Enum { name, variants }, + ) + .map_err(Error::Deserialize)?, + } + .map_err(Error::Serialize) + } +} + +struct Expected<'a>(fmt::Arguments<'a>); + +impl de::Expected for Expected<'_> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "{}", self.0) + } +} + +#[cfg(test)] +mod tests { + use core::fmt::Debug; + use postcard::ser_flavors::Flavor; + use postcard_schema::Schema; + use serde::de::DeserializeOwned; + use serde_json::json; + + use super::*; + + #[derive(Serialize, Deserialize, Schema, PartialEq, Debug)] + enum Enum { + Struct { a: u8, b: u8 }, + Tuple(bool, u8), + Newtype(u32), + Unit, + } + + #[derive(Serialize, Deserialize, Schema, PartialEq, Debug)] + struct Struct { + a: Option, + b: Enum, + c: u8, + } + + fn postcard_to_json(postcard: &[u8]) -> serde_json::Value { + let schema = T::SCHEMA.into(); + let leaky = lossless::reserialize_leaky( + &schema, + &mut postcard::Deserializer::from_bytes(postcard), + serde_json::value::Serializer, + ) + .unwrap(); + let lossy = lossy::reserialize_with_structs_and_enums_as_maps( + &schema, + &mut postcard::Deserializer::from_bytes(postcard), + serde_json::value::Serializer, + ) + .unwrap(); + assert_eq!(leaky, lossy); + leaky + } + + fn json_to_postcard(json: &serde_json::Value) -> Vec { + let mut serializer = postcard::Serializer { + output: postcard::ser_flavors::AllocVec::new(), + }; + lossless::reserialize_leaky(&T::SCHEMA.into(), json, &mut serializer).unwrap(); + serializer.output.finalize().unwrap() + } + + fn test_postcard_to_json_and_back(value: T) + where + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, + { + let postcard_bytes = postcard::to_allocvec(&value).unwrap(); + let json = postcard_to_json::(&postcard_bytes); + assert_eq!(json, serde_json::to_value(&value).unwrap()); + assert_eq!(T::deserialize(&json).unwrap(), value); + + let roundtripped_postcard_bytes = json_to_postcard::(&json); + assert_eq!(roundtripped_postcard_bytes, postcard_bytes); + assert_eq!( + postcard::from_bytes::(&roundtripped_postcard_bytes).unwrap(), + value + ); + } + + fn test_json_to_postcard(json: serde_json::Value) + where + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, + { + let postcard_bytes = json_to_postcard::(&json); + let from_json = T::deserialize(&json).unwrap(); + let from_postcard_bytes = postcard::from_bytes::(&postcard_bytes).unwrap(); + assert_eq!(from_postcard_bytes, from_json); + } + + fn test_json_to_postcard_and_back(json: serde_json::Value) + where + T: Schema + Serialize + DeserializeOwned + Debug + PartialEq, + { + let postcard_bytes = json_to_postcard::(&json); + + let from_json = T::deserialize(&json).unwrap(); + let from_postcard_bytes = postcard::from_bytes::(&postcard_bytes).unwrap(); + assert_eq!(from_postcard_bytes, from_json); + + let json_roundtripped = postcard_to_json::(&postcard_bytes); + assert_eq!(json_roundtripped, json); + + let from_json_roundtripped = T::deserialize(&json_roundtripped).unwrap(); + assert_eq!(from_json_roundtripped, from_json); + } + + #[test] + fn json() { + use test_postcard_to_json_and_back as test; + test(Enum::Struct { a: 5, b: 10 }); + test(Enum::Tuple(false, 15)); + test(Enum::Newtype(20)); + test(Enum::Unit); + test(Struct { + a: Some(5), + b: Enum::Struct { a: 10, b: 100 }, + c: 7, + }); + } + + #[test] + /// Make sure reserialization handles out-of-order struct fields correctly. + /// Serializers like postcard rely on struct fields being serialized in order. + fn out_of_order_fields() { + use test_json_to_postcard_and_back as test; + test::(json!({"Struct": {"b": 10, "a": 5}})); + test::(json!({"Struct": {"a": 5, "b": 0}})); + + let nested = json!({"Struct": {"b": 50, "a": 100}}); + test::(json!({"a": 5, "b": nested, "c": 10})); + test::(json!({"b": nested, "a": 5, "c": 10})); + test::(json!({"b": nested, "c": 10, "a": 5})); + test::(json!({"a": 5, "c": 10, "b": nested})); + test::(json!({"c": 10, "a": 5, "b": nested})); + } + + #[test] + fn extra_fields() { + use test_json_to_postcard as test; + test::(json!({"Struct": {"a": 5, "b": 0, "UNUSED": 10}})); + test::(json!({"a": 5, "xyz": "wat", "b": {"Newtype": 32}, "c": 10})); + } + + #[test] + #[should_panic = "missing field `b`"] + fn missing_fields() { + test_json_to_postcard::(json!({"Struct": {"a": 5}})); + } + + #[test] + #[should_panic = "invalid length 1, expected tuple variant Enum::Tuple with 2 elements"] + fn missing_tuple_fields() { + test_json_to_postcard::(json!({"Tuple": [false]})); + } + + #[test] + /// Make sure both deserializer and serializer errors are bubbled up + fn errors() { + use postcard_schema::Schema; + + assert!(matches!( + dbg!(lossless::reserialize_leaky( + &u8::SCHEMA.into(), + &mut postcard::Deserializer::from_bytes(&[]), + serde_json::value::Serializer + )), + Err(Error::Deserialize( + postcard::Error::DeserializeUnexpectedEnd + )) + )); + assert!(matches!( + dbg!(lossless::reserialize_leaky( + &u8::SCHEMA.into(), + &mut postcard::Deserializer::from_bytes(&[5]), + &mut serde_json::Serializer::new(std::io::Cursor::new([].as_mut_slice())) + )), + Err(Error::Serialize(_)) + )); + } +} diff --git a/source/postcard-dyn/src/reserialize/expecting.rs b/source/postcard-dyn/src/reserialize/expecting.rs new file mode 100644 index 00000000..0b2af9b2 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/expecting.rs @@ -0,0 +1,118 @@ +use core::{ + fmt::{self, Display}, + ops::RangeTo, +}; + +use postcard_schema::schema::owned::OwnedVariant; +use serde::de::{self, Expected}; + +pub trait Unexpected: de::Error { + fn missing_elements(len: usize, expected: &dyn Expected, expected_elements: usize) -> Self { + Self::invalid_length( + len, + &super::Expected(format_args!( + "{expected} with {expected_elements} element{}", + if expected_elements == 1 { "" } else { "s" }, + )), + ) + } + + fn unknown_variant_index(index: impl Into, expected: RangeTo) -> Self { + Self::invalid_value( + de::Unexpected::Unsigned(index.into()), + &super::Expected(format_args!("variant index 0 <= i < {}", expected.end)), + ) + } +} + +impl Unexpected for Error {} + +pub struct Tuple; + +pub struct Struct<'a, Data> { + pub name: &'a str, + pub data: Data, +} + +pub struct Enum<'name, 'schema> { + pub name: &'name str, + pub variants: &'schema [OwnedVariant], +} + +pub struct Variant<'a, Data> { + pub enum_name: &'a str, + pub variant_index: u32, + pub variant_name: &'a str, + pub data: Data, +} + +pub mod data { + use postcard_schema::schema::owned::{OwnedDataModelType, OwnedNamedField}; + + pub struct Unit; + pub struct Newtype<'a> { + pub schema: &'a OwnedDataModelType, + } + pub struct Tuple<'a> { + pub elements: &'a [OwnedDataModelType], + } + pub struct Struct<'a> { + pub fields: &'a [OwnedNamedField], + } +} + +impl Expected for Tuple { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a tuple") + } +} + +impl Expected for Struct<'_, data::Unit> { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "unit struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Newtype<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "tuple struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Tuple<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "tuple struct {}", self.name) + } +} + +impl Expected for Struct<'_, data::Struct<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "struct {}", self.name) + } +} + +impl Expected for Enum<'_, '_> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "enum {}", self.name) + } +} + +impl Expected for Variant<'_, data::Struct<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "struct variant {}::{}", + self.enum_name, self.variant_name + ) + } +} + +impl Expected for Variant<'_, data::Tuple<'_>> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!( + formatter, + "tuple variant {}::{}", + self.enum_name, self.variant_name + ) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless.rs b/source/postcard-dyn/src/reserialize/lossless.rs new file mode 100644 index 00000000..9a7bf491 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless.rs @@ -0,0 +1,249 @@ +//! Lossless reserialization. +//! +//! As noted [above](super), lossless serialization is only possible by compromising elsewhere. +//! This module provides implementations with different compromises: +//! - [`reserialize_leaky()`] leaks memory for each unique struct/enum/variant/field name + +use core::{cell::RefCell, fmt, str}; + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{ + de::{self, Deserializer}, + ser::Serializer, +}; + +use crate::Error; + +use super::{ + expecting, + strategy::{self, Strategy as _}, + Context, +}; + +mod interned; +use interned::Interned; + +mod enums; +mod structs; +mod tuples; + +/// Reserialize [`postcard`]-encoded data losslessly, **leaking memory**. +/// +/// In order to serialize structs and enums losslessly, this **allocates and leaks each unique +/// struct/enum/variant/field name, and the list of field names for each struct**. +/// +/// # Examples +/// +/// ``` +/// # fn main() -> Result<(), Box> { +/// # use postcard::ser_flavors::Flavor; +/// use postcard_schema::Schema; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize, Schema, PartialEq, Debug)] +/// enum Foo { +/// Bar { a: u8, b: u8 }, +/// } +/// +/// let value = Foo::Bar { a: 5, b: 10 }; +/// let bytes = postcard::to_allocvec(&value)?; +/// let mut serializer = postcard::Serializer { +/// output: postcard::ser_flavors::StdVec::new(), +/// }; +/// postcard_dyn::reserialize::lossless::reserialize_leaky( +/// &Foo::SCHEMA.into(), +/// &mut postcard::Deserializer::from_bytes(&bytes), +/// &mut serializer, +/// )?; +/// let out = serializer.output.finalize()?; +/// let deserialized: Foo = postcard::from_bytes(&out)?; +/// assert_eq!(deserialized, value); +/// # Ok(()) +/// # } +/// ``` +pub fn reserialize_leaky<'de, D, S>( + schema: &OwnedDataModelType, + deserializer: D, + serializer: S, +) -> Result> +where + D: Deserializer<'de>, + S: Serializer, +{ + Strategy.reserialize(schema, deserializer, serializer) +} + +/// Reserialize structs and enums losslessly, **leaking memory**. +struct Strategy; + +impl Strategy { + fn with_interned(&self, f: impl FnOnce(&mut Interned) -> T) -> T { + thread_local! { + static INTERNED: RefCell = RefCell::new(Default::default()); + } + INTERNED.with_borrow_mut(f) + } +} + +impl strategy::Strategy for Strategy { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + name: &str, + ) -> Result, D::Error> { + struct Visitor { + serializer: S, + expecting: expecting::Struct<'static, expecting::data::Unit>, + } + + impl de::Visitor<'_> for Visitor { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_unit(self) -> Result { + Ok(self.serializer.serialize_unit_struct(self.expecting.name)) + } + } + + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(name)); + deserializer.deserialize_unit_struct( + name, + Visitor { + serializer, + expecting: expecting::Struct { + name, + data: expecting::data::Unit, + }, + }, + ) + } + + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error> { + struct Visitor<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: S, + expecting: expecting::Struct<'static, expecting::data::Newtype<'a>>, + } + + impl<'de, S: Serializer> de::Visitor<'de> for Visitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + self.context + .reserialize_ty(self.expecting.data.schema, deserializer, |inner| { + self.serializer + .serialize_newtype_struct(self.expecting.name, inner) + }) + } + } + + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(expecting.name)); + deserializer.deserialize_newtype_struct( + name, + Visitor { + context, + serializer, + expecting: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) + } + + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Tuple>, + ) -> Result, D::Error> { + let name = context + .strategy + .with_interned(|interned| interned.intern_identifier(expecting.name)); + deserializer.deserialize_tuple_struct( + name, + expecting.data.elements.len(), + super::tuple::Visitor { + context, + serializer, + fields: expecting.data.elements, + reserializer: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) + } + + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Struct>, + ) -> Result, D::Error> { + let fields = expecting.data.fields; + let (name, field_names) = context.strategy.with_interned(|interned| { + let name = interned.intern_identifier(expecting.name); + let field_names = interned.intern_slice(fields.iter().map(|f| f.name.as_ref())); + (name, field_names) + }); + deserializer.deserialize_struct( + name, + field_names, + structs::Visitor { + context, + serializer, + fields, + field_names, + reserializer: expecting::Struct { + name, + data: expecting.data, + }, + }, + ) + } + + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error> { + let variants = expecting.variants; + let (name, variant_names) = context.strategy.with_interned(|interned| { + let name = interned.intern_identifier(expecting.name); + let variant_names = interned.intern_slice(variants.iter().map(|v| v.name.as_ref())); + (name, variant_names) + }); + deserializer.deserialize_enum( + name, + variant_names, + enums::Visitor { + context, + serializer, + expecting: expecting::Enum { name, variants }, + variant_names, + }, + ) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/enums.rs b/source/postcard-dyn/src/reserialize/lossless/enums.rs new file mode 100644 index 00000000..3b3137bb --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/enums.rs @@ -0,0 +1,172 @@ +use core::{fmt, str}; + +use postcard_schema::schema::owned::{OwnedData, OwnedDataModelType, OwnedVariant}; +use serde::{ + de::{self, DeserializeSeed, EnumAccess, VariantAccess}, + Deserializer, Serializer, +}; + +use crate::reserialize::{ + self, + expecting::{self, Unexpected}, + Context, +}; + +use super::Strategy; + +pub struct Visitor<'a, S> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub expecting: expecting::Enum<'static, 'a>, + pub variant_names: &'static [&'static str], +} + +impl<'de, S: Serializer> de::Visitor<'de> for Visitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_enum>(self, data: A) -> Result { + let ((variant_index, variant_name, variant), deserializer) = + data.variant_seed(VariantVisitor { + variants: self.expecting.variants, + variant_names: self.variant_names, + })?; + match variant { + OwnedData::Unit => { + deserializer.unit_variant()?; + Ok(self.serializer.serialize_unit_variant( + self.expecting.name, + variant_index, + variant_name, + )) + } + OwnedData::Newtype(inner) => deserializer.newtype_variant_seed(NewtypeVariantSeed { + context: self.context, + schema: inner, + serializer: self.serializer, + location: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Newtype { schema: inner }, + }, + }), + OwnedData::Tuple(fields) => deserializer.tuple_variant( + fields.len(), + reserialize::tuple::Visitor { + context: self.context, + serializer: self.serializer, + fields, + reserializer: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Tuple { elements: fields }, + }, + }, + ), + OwnedData::Struct(fields) => { + let field_names = self.context.strategy.with_interned(|interned| { + interned.intern_slice(fields.iter().map(|f| f.name.as_ref())) + }); + deserializer.struct_variant( + field_names, + super::structs::Visitor { + context: self.context, + serializer: self.serializer, + fields, + field_names, + reserializer: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name, + data: expecting::data::Struct { fields }, + }, + }, + ) + } + } + } +} + +struct VariantVisitor<'a> { + variants: &'a [OwnedVariant], + variant_names: &'static [&'static str], +} + +impl<'a, 'de> DeserializeSeed<'de> for VariantVisitor<'a> { + type Value = (u32, &'static str, &'a OwnedData); + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_identifier(self) + } +} + +impl<'a> de::Visitor<'_> for VariantVisitor<'a> { + type Value = (u32, &'static str, &'a OwnedData); + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "variant identifier") + } + + fn visit_u64(self, value: u64) -> Result { + let err = || E::unknown_variant_index(value, ..self.variants.len()); + let index = u32::try_from(value).map_err(|_| err())?; + let (name, schema) = { + let idx = usize::try_from(value).map_err(|_| err())?; + (self.variant_names.get(idx)) + .zip(self.variants.get(idx)) + .ok_or_else(err)? + }; + Ok((index, name, &schema.data)) + } + + fn visit_str(self, value: &str) -> Result { + self.find(value.as_bytes()) + .ok_or_else(|| E::unknown_variant(value, self.variant_names)) + } + + fn visit_bytes(self, value: &[u8]) -> Result { + self.find(value).ok_or_else(|| match str::from_utf8(value) { + Ok(value) => E::unknown_variant(value, self.variant_names), + Err(_) => E::invalid_value(de::Unexpected::Bytes(value), &self), + }) + } +} + +impl<'a> VariantVisitor<'a> { + fn find(&self, variant: &[u8]) -> Option<(u32, &'static str, &'a OwnedData)> { + (self.variant_names.iter()) + .zip(self.variants) + .enumerate() + .find_map(|(index, (&name, schema))| { + (name.as_bytes() == variant).then_some((index as u32, name, &schema.data)) + }) + } +} + +struct NewtypeVariantSeed<'a, S> { + context: &'a Context<'a, Strategy>, + schema: &'a OwnedDataModelType, + serializer: S, + location: expecting::Variant<'static, expecting::data::Newtype<'a>>, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for NewtypeVariantSeed<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |inner| { + self.serializer.serialize_newtype_variant( + self.location.enum_name, + self.location.variant_index, + self.location.variant_name, + inner, + ) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/interned.rs b/source/postcard-dyn/src/reserialize/lossless/interned.rs new file mode 100644 index 00000000..4689b7fe --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/interned.rs @@ -0,0 +1,80 @@ +use core::hash::{Hash, Hasher}; + +use hashbrown::HashSet; + +#[derive(Default)] +pub struct Interned { + strings: HashSet<&'static str>, + slices: HashSet, +} + +impl Interned { + pub fn intern_identifier(&mut self, s: &str) -> &'static str { + Self::intern_str(&mut self.strings, s) + } + + fn intern_str(strings: &mut HashSet<&'static str>, s: &str) -> &'static str { + strings.get_or_insert_with(s, |s| String::leak(s.to_string())) + } + + pub fn intern_slice<'a>( + &mut self, + strings: impl IntoIterator, + ) -> &'static [&'static str] { + let Slice(slice) = self + .slices + .get_or_insert_with(&Iter(strings.into_iter()), |elements| { + let strings = elements.0.clone(); + let interned = strings.map(|s| Self::intern_str(&mut self.strings, s)); + Slice(Box::leak(interned.collect())) + }); + slice + } +} + +#[derive(PartialEq, Eq)] +struct Slice(&'static [&'static str]); +struct Iter<'a, I: Iterator>(I); + +impl Hash for Slice { + fn hash(&self, state: &mut H) { + for s in self.0 { + s.hash(state) + } + } +} + +impl<'a, I: Iterator + Clone> Hash for Iter<'a, I> { + fn hash(&self, state: &mut H) { + for s in self.0.clone() { + s.hash(state) + } + } +} + +impl<'a, I> hashbrown::Equivalent for Iter<'a, I> +where + I: Iterator + Clone, +{ + fn equivalent(&self, slice: &Slice) -> bool { + self.0.clone().eq(slice.0.iter().copied()) + } +} + +#[cfg(test)] +mod tests { + use super::Interned; + + #[test] + fn basic() { + let mut interned = Interned::default(); + + assert_eq!(interned.intern_identifier("hello"), "hello"); + + let slices: &[&[&str]] = &[&[], &["foo"], &["foo", "bar"]]; + for &slice in slices { + assert_eq!(interned.intern_slice(slice.iter().copied()), slice); + } + assert_eq!(interned.slices.len(), slices.len()); + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/structs.rs b/source/postcard-dyn/src/reserialize/lossless/structs.rs new file mode 100644 index 00000000..6e584662 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/structs.rs @@ -0,0 +1,286 @@ +use core::fmt; +use std::collections::HashMap; + +use postcard_schema::schema::owned::{OwnedDataModelType, OwnedNamedField}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, Error as _, MapAccess, SeqAccess}, + ser::{self, Error as _, Serialize, SerializeStruct, Serializer}, +}; + +use crate::reserialize::{ + expecting::{self, Unexpected}, + Context, +}; + +use super::Strategy; + +pub struct Visitor<'a, S, Strategy, Reserializer> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub reserializer: Reserializer, + pub fields: &'a [OwnedNamedField], + pub field_names: &'static [&'static str], +} + +trait Reserializer: de::Expected { + type SerializeFields: SerializeStruct; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result; +} + +struct SerializeStructVariant(T); + +impl SerializeStruct for SerializeStructVariant { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(key, value) + } + + fn end(self) -> Result { + self.0.end() + } +} + +impl Reserializer for expecting::Variant<'static, expecting::data::Struct<'_>> { + type SerializeFields = SerializeStructVariant; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_struct_variant(self.enum_name, self.variant_index, self.variant_name, len) + .map(SerializeStructVariant) + } +} + +impl Reserializer for expecting::Struct<'static, expecting::data::Struct<'_>> { + type SerializeFields = S::SerializeStruct; + + fn reserialize_struct( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer.serialize_struct(self.name, len) + } +} + +impl<'de, S, Reserializer> de::Visitor<'de> for Visitor<'_, S, Strategy, Reserializer> +where + S: Serializer, + Reserializer: self::Reserializer, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.reserializer, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut serializer = match self + .reserializer + .reserialize_struct(self.serializer, self.field_names.len()) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let fields = (self.field_names.iter()) + .zip(self.fields) + .map(|(&name, field)| (name, &field.ty)); + for (idx, (name, schema)) in fields.enumerate() { + let seed = FieldSeed { + context: self.context, + serializer: &mut serializer, + name, + schema, + }; + let res = seq.next_element_seed(seed)?.ok_or_else(|| { + A::Error::missing_elements(idx, &self.reserializer, self.fields.len()) + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } + + fn visit_map>(self, mut map: A) -> Result { + let mut serializer = match self + .reserializer + .reserialize_struct(self.serializer, self.field_names.len()) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + + let key = FieldVisitor { + fields: self.fields, + field_names: self.field_names, + }; + let mut remaining_fields = self.field_names.iter().peekable(); + let mut out_of_order_fields = None; + while let Some(field) = map.next_key_seed(&key)? { + match field { + Err(Ignored) => { + // This only works for self-describing formats, but it should only + // be self-describing formats that deserialize to ignored fields. + let de::IgnoredAny = map.next_value::()?; + } + Ok((name, schema)) if remaining_fields.next_if_eq(&&name).is_some() => { + let res = map.next_value_seed(FieldSeed { + context: self.context, + serializer: &mut serializer, + name, + schema, + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok((name, schema)) => { + // Fields were deserialized out-of-order. Serializers assume fields are + // serialized in-order, so buffer up the out of order fields then serialize + // them in order. + let out_of_order = out_of_order_fields.get_or_insert_with(|| { + OutOfOrderFields(HashMap::with_capacity(remaining_fields.len())) + }); + let res: Result<(), serde_content::Error> = map.next_value_seed(FieldSeed { + context: self.context, + serializer: out_of_order, + name, + schema, + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(S::Error::custom(err))), + } + } + } + } + let mut out_of_order = out_of_order_fields + .map(|OutOfOrderFields(out_of_order)| out_of_order) + .unwrap_or(HashMap::new()); + for field in remaining_fields { + match out_of_order.remove(field) { + Some(value) => match serializer.serialize_field(field, &value) { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + }, + None => return Err(A::Error::missing_field(field)), + } + } + for field in self.field_names { + if out_of_order.contains_key(field) { + return Err(A::Error::duplicate_field(field)); + } + } + Ok(serializer.end()) + } +} + +struct FieldSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + name: &'static str, + schema: &'a OwnedDataModelType, +} + +impl<'de, S: SerializeStruct> DeserializeSeed<'de> for FieldSeed<'_, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |value| { + self.serializer.serialize_field(self.name, value) + }) + } +} + +struct FieldVisitor<'a> { + fields: &'a [OwnedNamedField], + field_names: &'static [&'static str], +} + +struct Ignored; + +impl<'a, 'de> DeserializeSeed<'de> for &FieldVisitor<'a> { + type Value = Result<(&'static str, &'a OwnedDataModelType), Ignored>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_identifier(self) + } +} + +impl<'a> de::Visitor<'_> for &FieldVisitor<'a> { + type Value = Result<(&'static str, &'a OwnedDataModelType), Ignored>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "field identifier") + } + + fn visit_u64(self, value: u64) -> Result { + Ok((usize::try_from(value).ok()) + .and_then(|idx| { + let (&name, schema) = (self.field_names.get(idx)).zip(self.fields.get(idx))?; + Some((name, &schema.ty)) + }) + .ok_or(Ignored)) + } + + fn visit_str(self, value: &str) -> Result { + Ok(self.find(value.as_bytes())) + } + + fn visit_bytes(self, value: &[u8]) -> Result { + Ok(self.find(value)) + } +} + +impl<'a> FieldVisitor<'a> { + fn find(&self, field: &[u8]) -> Result<(&'static str, &'a OwnedDataModelType), Ignored> { + self.field_names + .iter() + .zip(self.fields) + .find_map(|(&name, schema)| (name.as_bytes() == field).then_some((name, &schema.ty))) + .ok_or(Ignored) + } +} + +#[derive(Debug)] +struct OutOfOrderFields<'a>(HashMap<&'a str, serde_content::Value<'a>>); + +impl<'a> SerializeStruct for OutOfOrderFields<'a> { + type Ok = HashMap<&'a str, serde_content::Value<'a>>; + type Error = serde_content::Error; + + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> Result<(), Self::Error> { + let value = value.serialize(serde_content::Serializer::new())?; + debug_assert!(self.0.len() < self.0.capacity()); + self.0.insert(key, value); + Ok(()) + } + + fn end(self) -> Result { + Ok(self.0) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossless/tuples.rs b/source/postcard-dyn/src/reserialize/lossless/tuples.rs new file mode 100644 index 00000000..8e1a07d5 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossless/tuples.rs @@ -0,0 +1,66 @@ +use serde::ser::{self, Serialize, SerializeTuple, Serializer}; + +use crate::reserialize::{expecting, tuple::Reserializer}; + +pub struct SerializeTupleVariant(T); +pub struct SerializeTupleStruct(T); + +impl Reserializer for expecting::Variant<'static, expecting::data::Tuple<'_>> { + type SerializeTuple = SerializeTupleVariant; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_tuple_variant(self.enum_name, self.variant_index, self.variant_name, len) + .map(SerializeTupleVariant) + } +} + +impl Reserializer for expecting::Struct<'static, expecting::data::Tuple<'_>> { + type SerializeTuple = SerializeTupleStruct; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer + .serialize_tuple_struct(self.name, len) + .map(SerializeTupleStruct) + } +} + +impl SerializeTuple for SerializeTupleVariant { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end() + } +} + +impl SerializeTuple for SerializeTupleStruct { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + self.0.serialize_field(value) + } + + fn end(self) -> Result { + self.0.end() + } +} diff --git a/source/postcard-dyn/src/reserialize/lossy.rs b/source/postcard-dyn/src/reserialize/lossy.rs new file mode 100644 index 00000000..adfddfec --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossy.rs @@ -0,0 +1,155 @@ +//! Lossy reserialization. +//! +//! As noted [above](super), lossless serialization is only possible by compromising elsewhere. +//! This module provides lossy implementations with compromises in the serialization format: +//! - [`reserialize_with_structs_and_enums_as_maps()`] reserializes structs and enums as maps +//! instead of actual structs and enums. + +use postcard::de_flavors::Flavor; +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{ + de::{Deserialize, Deserializer}, + ser::{Serialize, Serializer}, +}; + +use crate::{reserialize, Error}; + +use super::{ + expecting, + strategy::{self, Strategy as _}, + Context, +}; + +mod enums; +mod structs; + +/// Reserialize [`postcard`]-encoded data, transforming structs and enums into maps. +/// +/// - Structs are transformed into maps with field names (as strings) for keys and field values +/// for values. +/// - Unit enum variants are transformed into the variant name as a string. +/// - Data-carrying (i.e., non-unit) enum variants are transformed into a single-element map +/// from the variant name (as a string) to the variant's data. +/// +/// This mirrors [`serde_json`]'s behavior. +/// +/// # Examples +/// +/// ``` +/// # fn main() -> Result<(), Box> { +/// # use postcard::ser_flavors::Flavor; +/// # use std::collections::BTreeMap; +/// use postcard_schema::Schema; +/// use serde::Serialize; +/// +/// #[derive(Serialize, Schema)] +/// enum Foo { +/// Bar { a: u8, b: u8 }, +/// } +/// +/// let bytes = postcard::to_allocvec(&Foo::Bar { a: 5, b: 10 })?; +/// let mut serializer = postcard::Serializer { +/// output: postcard::ser_flavors::StdVec::new(), +/// }; +/// postcard_dyn::reserialize::lossy::reserialize_with_structs_and_enums_as_maps( +/// &Foo::SCHEMA.into(), +/// &mut postcard::Deserializer::from_bytes(&bytes), +/// &mut serializer, +/// )?; +/// let out = serializer.output.finalize()?; +/// let deserialized: BTreeMap<&str, BTreeMap<&str, u8>> = postcard::from_bytes(&out)?; +/// assert_eq!( +/// format!("{deserialized:?}"), +/// r#"{"Bar": {"a": 5, "b": 10}}"# +/// ); +/// # Ok(()) +/// # } +/// ``` +pub fn reserialize_with_structs_and_enums_as_maps<'de, F, S>( + schema: &OwnedDataModelType, + deserializer: &mut postcard::Deserializer<'de, F>, + serializer: S, +) -> Result> +where + F: Flavor<'de>, + S: Serializer, +{ + Strategy.reserialize(schema, deserializer, serializer) +} + +/// Reserialize structs and enums as maps similar to [`serde_json`]. +struct Strategy; + +impl strategy::Strategy for Strategy { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + _context: &Context<'_, Self>, + deserializer: D, + serializer: S, + _name: &str, + ) -> Result, D::Error> { + <()>::deserialize(deserializer)?; + Ok(serializer.serialize_unit()) + } + + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error> { + context.reserialize_ty(expecting.data.schema, deserializer, |inner| { + inner.serialize(serializer) + }) + } + + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Tuple<'_>>, + ) -> Result, D::Error> { + deserializer.deserialize_tuple( + expecting.data.elements.len(), + reserialize::tuple::Visitor { + context, + serializer, + fields: expecting.data.elements, + reserializer: expecting::Tuple, + }, + ) + } + + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Struct<'_>>, + ) -> Result, D::Error> { + deserializer.deserialize_tuple( + expecting.data.fields.len(), + reserialize::tuple::Visitor { + context, + serializer, + fields: expecting.data.fields.iter().map(|f| &f.ty), + reserializer: structs::ReserializeStructAsMap { expecting }, + }, + ) + } + + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error> { + // Postcard encodes enums as (index, value) + deserializer.deserialize_tuple( + 2, + enums::Visitor { + serializer, + context, + expecting, + }, + ) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossy/enums.rs b/source/postcard-dyn/src/reserialize/lossy/enums.rs new file mode 100644 index 00000000..890cf4c0 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossy/enums.rs @@ -0,0 +1,293 @@ +use core::{fmt, marker::PhantomData}; + +use postcard_schema::schema::owned::{OwnedData, OwnedDataModelType}; +use serde::{ + de::{self, DeserializeSeed, Deserializer, SeqAccess}, + ser::{Error as _, SerializeMap, SerializeTuple, Serializer}, +}; + +use crate::{ + reserialize::{ + expecting::{self, Unexpected}, + Context, ReserializeFn, + }, + Error, +}; + +use super::Strategy; + +pub struct Visitor<'a, S> { + pub serializer: S, + pub context: &'a Context<'a, Strategy>, + pub expecting: expecting::Enum<'a, 'a>, +} + +impl<'de, S: Serializer> de::Visitor<'de> for Visitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let variant_index: u32 = seq.next_element()?.unwrap(); + let variant = (usize::try_from(variant_index).ok()) + .and_then(|v| self.expecting.variants.get(v)) + .ok_or_else(|| { + A::Error::unknown_variant_index(variant_index, ..self.expecting.variants.len()) + })?; + + let err = || S::Error::custom("missing variant data"); + Ok(match &variant.data { + OwnedData::Unit => self.serializer.serialize_str(&variant.name), + OwnedData::Newtype(inner) => seq + .next_element_seed(NewtypeVariantSeed { + serializer: self.serializer, + context: self.context, + variant: &variant.name, + inner, + })? + .ok_or_else(err) + .and_then(|res| res), + OwnedData::Tuple(fields) => seq + .next_element_seed(TupleVariantVisitor { + serializer: self.serializer, + context: self.context, + expecting: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name: &variant.name, + data: expecting::data::Tuple { elements: fields }, + }, + })? + .ok_or_else(err) + .and_then(|res| res), + OwnedData::Struct(fields) => seq + .next_element_seed(StructVariantVisitor { + serializer: self.serializer, + context: self.context, + expecting: expecting::Variant { + enum_name: self.expecting.name, + variant_index, + variant_name: &variant.name, + data: expecting::data::Struct { fields }, + }, + })? + .ok_or_else(err) + .and_then(|res| res), + }) + } +} + +struct NewtypeVariantSeed<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + variant: &'a str, + inner: &'a OwnedDataModelType, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for NewtypeVariantSeed<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let res = self + .context + .reserialize_ty(self.inner, deserializer, |value| { + serializer.serialize_entry(self.variant, value) + })?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + Ok(serializer.end()) + } +} + +struct TupleVariantVisitor<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + expecting: expecting::Variant<'a, expecting::data::Tuple<'a>>, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for TupleVariantVisitor<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_tuple(self.expecting.data.elements.len(), self) + } +} + +impl<'de, S: Serializer> de::Visitor<'de> for TupleVariantVisitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, seq: A) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + self.context.reserialize( + ReserializeTupleVariant { + context: self.context, + seq, + expecting: &self.expecting, + de: PhantomData, + }, + |data| { + serializer.serialize_entry(self.expecting.variant_name, data)?; + serializer.end() + }, + ) + } +} +struct ReserializeTupleVariant<'de, 'a, A> { + context: &'a Context<'a, Strategy>, + seq: A, + expecting: &'a expecting::Variant<'a, expecting::data::Tuple<'a>>, + de: PhantomData<&'de ()>, +} + +struct ElementSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedDataModelType, +} + +impl<'de, S: SerializeTuple> DeserializeSeed<'de> for ElementSeed<'_, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +} + +impl<'de, A: SeqAccess<'de>> ReserializeFn for ReserializeTupleVariant<'de, '_, A> { + type DeserializeError = A::Error; + + fn reserialize( + mut self, + serializer: S, + ) -> Result> { + let fields = self.expecting.data.elements; + let mut serializer = serializer + .serialize_tuple(fields.len()) + .map_err(Error::Serialize)?; + for (idx, field) in fields.iter().enumerate() { + self.seq + .next_element_seed(ElementSeed { + context: self.context, + serializer: &mut serializer, + schema: field, + }) + .map_err(Error::Deserialize)? + .ok_or_else(|| A::Error::missing_elements(idx, self.expecting, fields.len())) + .map_err(Error::Deserialize)? + .map_err(Error::Serialize)?; + } + serializer.end().map_err(Error::Serialize) + } +} + +struct StructVariantVisitor<'a, S> { + serializer: S, + context: &'a Context<'a, Strategy>, + expecting: expecting::Variant<'a, expecting::data::Struct<'a>>, +} + +impl<'de, S: Serializer> DeserializeSeed<'de> for StructVariantVisitor<'_, S> { + type Value = Result; + + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_tuple(self.expecting.data.fields.len(), self) + } +} + +impl<'de, S: Serializer> de::Visitor<'de> for StructVariantVisitor<'_, S> { + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } + + fn visit_seq>(self, seq: A) -> Result { + let mut serializer = match self.serializer.serialize_map(Some(1)) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + self.context.reserialize( + ReserializeStructVariant { + context: self.context, + seq, + expecting: &self.expecting, + de: PhantomData, + }, + |data| { + serializer.serialize_entry(self.expecting.variant_name, data)?; + serializer.end() + }, + ) + } +} + +struct ReserializeStructVariant<'a, 'de, A> { + context: &'a Context<'a, Strategy>, + seq: A, + expecting: &'a expecting::Variant<'a, expecting::data::Struct<'a>>, + de: PhantomData<&'de ()>, +} + +struct FieldSeed<'a, S> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + key: &'a str, + schema: &'a OwnedDataModelType, +} + +impl<'de, S: SerializeMap> DeserializeSeed<'de> for FieldSeed<'_, S> { + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |value| { + self.serializer.serialize_entry(self.key, value) + }) + } +} + +impl<'de, A: SeqAccess<'de>> ReserializeFn for ReserializeStructVariant<'_, 'de, A> { + type DeserializeError = A::Error; + + fn reserialize( + mut self, + serializer: S, + ) -> Result> { + let fields = self.expecting.data.fields; + let mut serializer = serializer + .serialize_map(Some(fields.len())) + .map_err(Error::Serialize)?; + for (idx, field) in fields.iter().enumerate() { + self.seq + .next_element_seed(FieldSeed { + context: self.context, + serializer: &mut serializer, + key: &field.name, + schema: &field.ty, + }) + .map_err(Error::Deserialize)? + .ok_or_else(|| A::Error::missing_elements(idx, self.expecting, fields.len())) + .map_err(Error::Deserialize)? + .map_err(Error::Serialize)?; + } + serializer.end().map_err(Error::Serialize) + } +} diff --git a/source/postcard-dyn/src/reserialize/lossy/structs.rs b/source/postcard-dyn/src/reserialize/lossy/structs.rs new file mode 100644 index 00000000..575fbc11 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/lossy/structs.rs @@ -0,0 +1,55 @@ +use core::slice; + +use postcard_schema::schema::owned::OwnedNamedField; +use serde::{ + de, + ser::{Serialize, SerializeMap, SerializeTuple, Serializer}, +}; + +use crate::reserialize::{self, expecting}; + +pub struct ReserializeStructAsMap<'a> { + pub expecting: expecting::Struct<'a, expecting::data::Struct<'a>>, +} + +impl de::Expected for ReserializeStructAsMap<'_> { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + de::Expected::fmt(&self.expecting, formatter) + } +} + +impl<'a, S: Serializer> reserialize::tuple::Reserializer for ReserializeStructAsMap<'a> { + type SerializeTuple = SerializeFieldsAsMapEntries<'a, S::SerializeMap>; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result::Error> { + let serializer = serializer.serialize_map(Some(len))?; + let fields = self.expecting.data.fields.iter(); + Ok(SerializeFieldsAsMapEntries { serializer, fields }) + } +} + +pub struct SerializeFieldsAsMapEntries<'a, S> { + serializer: S, + fields: slice::Iter<'a, OwnedNamedField>, +} + +impl SerializeTuple for SerializeFieldsAsMapEntries<'_, S> { + type Ok = S::Ok; + type Error = S::Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let field = self.fields.next().unwrap(); + self.serializer.serialize_entry(&field.name, value) + } + + fn end(self) -> Result { + self.serializer.end() + } +} diff --git a/source/postcard-dyn/src/reserialize/map.rs b/source/postcard-dyn/src/reserialize/map.rs new file mode 100644 index 00000000..674fa7e3 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/map.rs @@ -0,0 +1,98 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{ + de::{self, DeserializeSeed, MapAccess}, + ser::SerializeMap, + Deserializer, Serializer, +}; + +use super::Context; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub key: &'a OwnedDataModelType, + pub val: &'a OwnedDataModelType, +} + +impl<'de, S, Strategy> de::Visitor<'de> for Visitor<'_, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a map") + } + + fn visit_map>(self, mut map: A) -> Result { + let mut serializer = match self.serializer.serialize_map(map.size_hint()) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + while let Some(res) = map.next_key_seed(KeySeed { + context: self.context, + serializer: &mut serializer, + schema: self.key, + })? { + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + let seed = ValueSeed { + context: self.context, + serializer: &mut serializer, + schema: self.val, + }; + match map.next_value_seed(seed)? { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct KeySeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedDataModelType, +} + +impl<'de, S, Strategy> DeserializeSeed<'de> for KeySeed<'_, S, Strategy> +where + S: SerializeMap, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |key| { + self.serializer.serialize_key(key) + }) + } +} + +struct ValueSeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schema: &'a OwnedDataModelType, +} + +impl<'de, S, Strategy> DeserializeSeed<'de> for ValueSeed<'_, S, Strategy> +where + S: SerializeMap, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |val| { + self.serializer.serialize_value(val) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/option.rs b/source/postcard-dyn/src/reserialize/option.rs new file mode 100644 index 00000000..879539a4 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/option.rs @@ -0,0 +1,35 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{de, Deserializer, Serializer}; + +use super::Context; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub schema: &'a OwnedDataModelType, +} + +impl<'de, S, Strategy> de::Visitor<'de> for Visitor<'_, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("option") + } + + fn visit_none(self) -> Result { + Ok(self.serializer.serialize_none()) + } + + fn visit_some>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.schema, deserializer, |inner| { + self.serializer.serialize_some(inner) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/seq.rs b/source/postcard-dyn/src/reserialize/seq.rs new file mode 100644 index 00000000..0b516b35 --- /dev/null +++ b/source/postcard-dyn/src/reserialize/seq.rs @@ -0,0 +1,80 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{ + de::{self, DeserializeSeed, Error as _, SeqAccess}, + ser::SerializeSeq, + Deserializer, Serializer, +}; + +use super::{Context, Expected}; + +pub struct Visitor<'a, S, Strategy> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub schemas: &'a [OwnedDataModelType], +} + +impl<'de, S, Strategy> de::Visitor<'de> for Visitor<'_, S, Strategy> +where + S: Serializer, + Strategy: super::Strategy, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence") + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut serializer = match self.serializer.serialize_seq(seq.size_hint()) { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + let mut seed = ElementSeed { + context: self.context, + serializer: &mut serializer, + schemas: self.schemas, + idx: 0, + }; + while let Some(res) = seq.next_element_seed(&mut seed)? { + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct ElementSeed<'a, S: 'a, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + schemas: &'a [OwnedDataModelType], + idx: usize, +} + +impl<'de, S, Strategy> DeserializeSeed<'de> for &mut ElementSeed<'_, S, Strategy> +where + S: SerializeSeq, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let schemas = self.schemas; + let schema = schemas.get(self.idx).ok_or_else(|| { + D::Error::invalid_length( + self.idx + 1, + &Expected(format_args!("sequence of length {}", schemas.len())), + ) + })?; + self.context + .reserialize_ty(schema, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +} diff --git a/source/postcard-dyn/src/reserialize/strategy.rs b/source/postcard-dyn/src/reserialize/strategy.rs new file mode 100644 index 00000000..3133057a --- /dev/null +++ b/source/postcard-dyn/src/reserialize/strategy.rs @@ -0,0 +1,62 @@ +//! How to reserialize structs and enums to work around [`Deserializer`] and [`Serializer`]'s +//! `&'static str` requirements. + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{Deserializer, Serialize, Serializer}; + +use crate::Error; + +use super::{expecting, Context}; + +/// How to reserialize structs and enums to work around [`Deserializer`] and [`Serializer`]'s +/// `&'static str` requirements. +pub(super) trait Strategy: Sized { + fn reserialize_unit_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + name: &str, + ) -> Result, D::Error>; + + fn reserialize_newtype_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Newtype>, + ) -> Result, D::Error>; + + fn reserialize_tuple_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Tuple>, + ) -> Result, D::Error>; + + fn reserialize_struct<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Struct<'_, expecting::data::Struct>, + ) -> Result, D::Error>; + + fn reserialize_enum<'de, D: Deserializer<'de>, S: Serializer>( + context: &Context<'_, Self>, + deserializer: D, + serializer: S, + expecting: expecting::Enum<'_, '_>, + ) -> Result, D::Error>; + + fn reserialize<'de, D: Deserializer<'de>, S: Serializer>( + &self, + schema: &OwnedDataModelType, + deserializer: D, + serializer: S, + ) -> Result> { + let context = Context { strategy: self }; + match context.reserialize_ty(schema, deserializer, |value| value.serialize(serializer)) { + Ok(Ok(out)) => Ok(out), + Ok(Err(err)) => Err(Error::Serialize(err)), + Err(err) => Err(Error::Deserialize(err)), + } + } +} diff --git a/source/postcard-dyn/src/reserialize/tuple.rs b/source/postcard-dyn/src/reserialize/tuple.rs new file mode 100644 index 00000000..a61e99ee --- /dev/null +++ b/source/postcard-dyn/src/reserialize/tuple.rs @@ -0,0 +1,105 @@ +use core::fmt; + +use postcard_schema::schema::owned::OwnedDataModelType; +use serde::{ + de::{self, DeserializeSeed, SeqAccess}, + ser::SerializeTuple, + Deserializer, Serializer, +}; + +use super::{ + expecting::{self, Unexpected}, + Context, +}; + +pub struct Visitor<'a, S, Strategy, Fields, Reserializer> { + pub context: &'a Context<'a, Strategy>, + pub serializer: S, + pub fields: Fields, + pub reserializer: Reserializer, +} + +pub trait Reserializer: de::Expected { + type SerializeTuple: SerializeTuple; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result; +} + +impl Reserializer for expecting::Tuple { + type SerializeTuple = S::SerializeTuple; + + fn reserialize_tuple( + &self, + serializer: S, + len: usize, + ) -> Result { + serializer.serialize_tuple(len) + } +} + +impl<'de, 'schema, S, Strategy, Fields, Reserializer> de::Visitor<'de> + for Visitor<'_, S, Strategy, Fields, Reserializer> +where + S: Serializer, + Strategy: super::Strategy, + Fields: IntoIterator, + Reserializer: self::Reserializer, +{ + type Value = Result; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + de::Expected::fmt(&self.reserializer, formatter) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let fields = self.fields.into_iter(); + let num_fields = fields.len(); + let mut serializer = match self + .reserializer + .reserialize_tuple(self.serializer, num_fields) + { + Ok(serializer) => serializer, + Err(err) => return Ok(Err(err)), + }; + for (idx, field) in fields.enumerate() { + let seed = ElementSeed { + context: self.context, + serializer: &mut serializer, + field, + }; + let res = seq + .next_element_seed(seed)? + .ok_or_else(|| A::Error::missing_elements(idx, &self.reserializer, num_fields))?; + match res { + Ok(()) => {} + Err(err) => return Ok(Err(err)), + } + } + Ok(serializer.end()) + } +} + +struct ElementSeed<'a, S, Strategy> { + context: &'a Context<'a, Strategy>, + serializer: &'a mut S, + field: &'a OwnedDataModelType, +} + +impl<'de, S, Strategy> DeserializeSeed<'de> for ElementSeed<'_, S, Strategy> +where + S: SerializeTuple, + Strategy: super::Strategy, +{ + type Value = Result<(), S::Error>; + + fn deserialize>(self, deserializer: D) -> Result { + self.context + .reserialize_ty(self.field, deserializer, |element| { + self.serializer.serialize_element(element) + }) + } +}