Skip to content

Commit 8c18b14

Browse files
fix(services): classification measurement serialization (#1055)
* Fix classification measurement serialization * Add tests
1 parent 4d3e935 commit 8c18b14

File tree

1 file changed

+100
-42
lines changed

1 file changed

+100
-42
lines changed

services/src/api/model/datatypes.rs

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use geoengine_datatypes::primitives::{
77
use geoengine_macros::type_tag;
88
use ordered_float::NotNan;
99
use postgres_types::{FromSql, ToSql};
10+
use serde::de::Error as SerdeError;
11+
use serde::ser::SerializeMap;
1012
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Visitor};
1113
use snafu::ResultExt;
1214
use std::{
@@ -898,73 +900,71 @@ impl From<ContinuousMeasurement> for geoengine_datatypes::primitives::Continuous
898900

899901
#[type_tag(value = "classification")]
900902
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, ToSchema)]
901-
#[serde(
902-
try_from = "SerializableClassificationMeasurement",
903-
into = "SerializableClassificationMeasurement"
904-
)]
905903
pub struct ClassificationMeasurement {
906904
pub measurement: String,
907-
pub classes: HashMap<u8, String>,
905+
// use a BTreeMap to preserve the order of the keys
906+
#[serde(serialize_with = "serialize_classes")]
907+
#[serde(deserialize_with = "deserialize_classes")]
908+
pub classes: BTreeMap<u8, String>,
908909
}
909910

910-
impl From<geoengine_datatypes::primitives::ClassificationMeasurement>
911-
for ClassificationMeasurement
911+
fn serialize_classes<S>(classes: &BTreeMap<u8, String>, serializer: S) -> Result<S::Ok, S::Error>
912+
where
913+
S: Serializer,
912914
{
913-
fn from(value: geoengine_datatypes::primitives::ClassificationMeasurement) -> Self {
914-
Self {
915-
r#type: Default::default(),
916-
measurement: value.measurement,
917-
classes: value.classes,
918-
}
915+
let mut map = serializer.serialize_map(Some(classes.len()))?;
916+
for (k, v) in classes {
917+
map.serialize_entry(&k.to_string(), v)?;
919918
}
919+
map.end()
920920
}
921921

922-
impl From<ClassificationMeasurement>
923-
for geoengine_datatypes::primitives::ClassificationMeasurement
922+
fn deserialize_classes<'de, D>(deserializer: D) -> Result<BTreeMap<u8, String>, D::Error>
923+
where
924+
D: Deserializer<'de>,
924925
{
925-
fn from(value: ClassificationMeasurement) -> Self {
926-
Self {
927-
measurement: value.measurement,
928-
classes: value.classes,
929-
}
926+
let map = BTreeMap::<String, String>::deserialize(deserializer)?;
927+
let mut classes = BTreeMap::new();
928+
for (k, v) in map {
929+
classes.insert(
930+
k.parse::<u8>()
931+
.map_err(|e| D::Error::custom(format!("Failed to parse key as u8: {e}")))?,
932+
v,
933+
);
930934
}
935+
Ok(classes)
931936
}
932937

933-
/// A type that is solely for serde's serializability.
934-
/// You cannot serialize floats as JSON map keys.
935-
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
936-
pub struct SerializableClassificationMeasurement {
937-
pub measurement: String,
938-
// use a BTreeMap to preserve the order of the keys
939-
pub classes: BTreeMap<String, String>,
940-
}
941-
942-
impl From<ClassificationMeasurement> for SerializableClassificationMeasurement {
943-
fn from(measurement: ClassificationMeasurement) -> Self {
938+
impl From<geoengine_datatypes::primitives::ClassificationMeasurement>
939+
for ClassificationMeasurement
940+
{
941+
fn from(value: geoengine_datatypes::primitives::ClassificationMeasurement) -> Self {
944942
let mut classes = BTreeMap::new();
945-
for (k, v) in measurement.classes {
946-
classes.insert(k.to_string(), v);
943+
for (k, v) in value.classes {
944+
classes.insert(k, v);
947945
}
946+
948947
Self {
949-
measurement: measurement.measurement,
948+
r#type: Default::default(),
949+
measurement: value.measurement,
950950
classes,
951951
}
952952
}
953953
}
954954

955-
impl TryFrom<SerializableClassificationMeasurement> for ClassificationMeasurement {
956-
type Error = <u8 as FromStr>::Err;
957-
958-
fn try_from(measurement: SerializableClassificationMeasurement) -> Result<Self, Self::Error> {
955+
impl From<ClassificationMeasurement>
956+
for geoengine_datatypes::primitives::ClassificationMeasurement
957+
{
958+
fn from(measurement: ClassificationMeasurement) -> Self {
959959
let mut classes = HashMap::with_capacity(measurement.classes.len());
960960
for (k, v) in measurement.classes {
961-
classes.insert(k.parse::<u8>()?, v);
961+
classes.insert(k, v);
962962
}
963-
Ok(Self {
964-
r#type: Default::default(),
963+
964+
Self {
965965
measurement: measurement.measurement,
966966
classes,
967-
})
967+
}
968968
}
969969
}
970970

@@ -2189,3 +2189,61 @@ impl<'a> FromSql<'a> for CacheTtlSeconds {
21892189
<i32 as FromSql>::accepts(ty)
21902190
}
21912191
}
2192+
2193+
#[cfg(test)]
2194+
mod tests {
2195+
use crate::api::model::datatypes::ClassificationMeasurement;
2196+
use crate::error::Error;
2197+
use std::collections::BTreeMap;
2198+
2199+
#[test]
2200+
fn it_serializes_classification_measurement() -> Result<(), Error> {
2201+
let measurement = ClassificationMeasurement {
2202+
r#type: Default::default(),
2203+
measurement: "Test".to_string(),
2204+
classes: BTreeMap::<u8, String>::from([
2205+
(0, "Class 0".to_string()),
2206+
(1, "Class 1".to_string()),
2207+
]),
2208+
};
2209+
2210+
let serialized = serde_json::to_string(&measurement)?;
2211+
2212+
assert_eq!(
2213+
serialized,
2214+
r#"{"type":"classification","measurement":"Test","classes":{"0":"Class 0","1":"Class 1"}}"#
2215+
);
2216+
Ok(())
2217+
}
2218+
2219+
#[test]
2220+
fn it_deserializes_classification_measurement() -> Result<(), Error> {
2221+
let measurement = ClassificationMeasurement {
2222+
r#type: Default::default(),
2223+
measurement: "Test".to_string(),
2224+
classes: BTreeMap::<u8, String>::from([
2225+
(0, "Class 0".to_string()),
2226+
(1, "Class 1".to_string()),
2227+
]),
2228+
};
2229+
2230+
let serialized = r#"{"type":"classification","measurement":"Test","classes":{"0":"Class 0","1":"Class 1"}}"#;
2231+
let deserialized: ClassificationMeasurement = serde_json::from_str(serialized)?;
2232+
2233+
assert_eq!(measurement, deserialized);
2234+
Ok(())
2235+
}
2236+
2237+
#[test]
2238+
fn it_throws_error_on_deserializing_non_integer_classification_measurement_class_value() {
2239+
let serialized =
2240+
r#"{"type":"classification","measurement":"Test","classes":{"Zero":"Class 0"}}"#;
2241+
let deserialized = serde_json::from_str::<ClassificationMeasurement>(serialized);
2242+
2243+
assert!(deserialized.is_err());
2244+
assert_eq!(
2245+
deserialized.unwrap_err().to_string(),
2246+
"Failed to parse key as u8: invalid digit found in string at line 1 column 75"
2247+
);
2248+
}
2249+
}

0 commit comments

Comments
 (0)