diff --git a/services/src/api/model/datatypes.rs b/services/src/api/model/datatypes.rs index 1ab5be794..b23f7251c 100644 --- a/services/src/api/model/datatypes.rs +++ b/services/src/api/model/datatypes.rs @@ -7,6 +7,8 @@ use geoengine_datatypes::primitives::{ use geoengine_macros::type_tag; use ordered_float::NotNan; use postgres_types::{FromSql, ToSql}; +use serde::de::Error as SerdeError; +use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor}; use snafu::ResultExt; use std::{ @@ -898,73 +900,71 @@ impl From for geoengine_datatypes::primitives::Continuous #[type_tag(value = "classification")] #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, ToSchema)] -#[serde( - try_from = "SerializableClassificationMeasurement", - into = "SerializableClassificationMeasurement" -)] pub struct ClassificationMeasurement { pub measurement: String, - pub classes: HashMap, + // use a BTreeMap to preserve the order of the keys + #[serde(serialize_with = "serialize_classes")] + #[serde(deserialize_with = "deserialize_classes")] + pub classes: BTreeMap, } -impl From - for ClassificationMeasurement +fn serialize_classes(classes: &BTreeMap, serializer: S) -> Result +where + S: Serializer, { - fn from(value: geoengine_datatypes::primitives::ClassificationMeasurement) -> Self { - Self { - r#type: Default::default(), - measurement: value.measurement, - classes: value.classes, - } + let mut map = serializer.serialize_map(Some(classes.len()))?; + for (k, v) in classes { + map.serialize_entry(&k.to_string(), v)?; } + map.end() } -impl From - for geoengine_datatypes::primitives::ClassificationMeasurement +fn deserialize_classes<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, { - fn from(value: ClassificationMeasurement) -> Self { - Self { - measurement: value.measurement, - classes: value.classes, - } + let map = BTreeMap::::deserialize(deserializer)?; + let mut classes = BTreeMap::new(); + for (k, v) in map { + classes.insert( + k.parse::() + .map_err(|e| D::Error::custom(format!("Failed to parse key as u8: {e}")))?, + v, + ); } + Ok(classes) } -/// A type that is solely for serde's serializability. -/// You cannot serialize floats as JSON map keys. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct SerializableClassificationMeasurement { - pub measurement: String, - // use a BTreeMap to preserve the order of the keys - pub classes: BTreeMap, -} - -impl From for SerializableClassificationMeasurement { - fn from(measurement: ClassificationMeasurement) -> Self { +impl From + for ClassificationMeasurement +{ + fn from(value: geoengine_datatypes::primitives::ClassificationMeasurement) -> Self { let mut classes = BTreeMap::new(); - for (k, v) in measurement.classes { - classes.insert(k.to_string(), v); + for (k, v) in value.classes { + classes.insert(k, v); } + Self { - measurement: measurement.measurement, + r#type: Default::default(), + measurement: value.measurement, classes, } } } -impl TryFrom for ClassificationMeasurement { - type Error = ::Err; - - fn try_from(measurement: SerializableClassificationMeasurement) -> Result { +impl From + for geoengine_datatypes::primitives::ClassificationMeasurement +{ + fn from(measurement: ClassificationMeasurement) -> Self { let mut classes = HashMap::with_capacity(measurement.classes.len()); for (k, v) in measurement.classes { - classes.insert(k.parse::()?, v); + classes.insert(k, v); } - Ok(Self { - r#type: Default::default(), + + Self { measurement: measurement.measurement, classes, - }) + } } } @@ -2189,3 +2189,61 @@ impl<'a> FromSql<'a> for CacheTtlSeconds { ::accepts(ty) } } + +#[cfg(test)] +mod tests { + use crate::api::model::datatypes::ClassificationMeasurement; + use crate::error::Error; + use std::collections::BTreeMap; + + #[test] + fn it_serializes_classification_measurement() -> Result<(), Error> { + let measurement = ClassificationMeasurement { + r#type: Default::default(), + measurement: "Test".to_string(), + classes: BTreeMap::::from([ + (0, "Class 0".to_string()), + (1, "Class 1".to_string()), + ]), + }; + + let serialized = serde_json::to_string(&measurement)?; + + assert_eq!( + serialized, + r#"{"type":"classification","measurement":"Test","classes":{"0":"Class 0","1":"Class 1"}}"# + ); + Ok(()) + } + + #[test] + fn it_deserializes_classification_measurement() -> Result<(), Error> { + let measurement = ClassificationMeasurement { + r#type: Default::default(), + measurement: "Test".to_string(), + classes: BTreeMap::::from([ + (0, "Class 0".to_string()), + (1, "Class 1".to_string()), + ]), + }; + + let serialized = r#"{"type":"classification","measurement":"Test","classes":{"0":"Class 0","1":"Class 1"}}"#; + let deserialized: ClassificationMeasurement = serde_json::from_str(serialized)?; + + assert_eq!(measurement, deserialized); + Ok(()) + } + + #[test] + fn it_throws_error_on_deserializing_non_integer_classification_measurement_class_value() { + let serialized = + r#"{"type":"classification","measurement":"Test","classes":{"Zero":"Class 0"}}"#; + let deserialized = serde_json::from_str::(serialized); + + assert!(deserialized.is_err()); + assert_eq!( + deserialized.unwrap_err().to_string(), + "Failed to parse key as u8: invalid digit found in string at line 1 column 75" + ); + } +}