Skip to content

Commit 992cdad

Browse files
authored
Merge pull request #72 from davidhewitt/deserialize-set
serialize set and frozenset
2 parents 2595552 + 2666d63 commit 992cdad

File tree

3 files changed

+89
-6
lines changed

3 files changed

+89
-6
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
### Fixed
2424
- Fix overflow error attempting to depythonize `u64` values greater than `i64::MAX` to types like `serde_json::Value`
25+
- Fix deserializing `set` and `frozenset` into Rust homogeneous containers
2526

2627
## 0.21.1 - 2024-04-02
2728

src/de.rs

+79-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use pyo3::{types::*, Bound};
22
use serde::de::{self, DeserializeOwned, IntoDeserializer};
33
use serde::Deserialize;
44

5-
use crate::error::{PythonizeError, Result};
5+
use crate::error::{ErrorImpl, PythonizeError, Result};
66

77
/// Attempt to convert a Python object to an instance of `T`
88
pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result<T>
@@ -44,6 +44,19 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
4444
}
4545
}
4646

47+
fn set_access(&self) -> Result<PySetAsSequence<'py>> {
48+
match self.input.downcast::<PySet>() {
49+
Ok(set) => Ok(PySetAsSequence::from_set(&set)),
50+
Err(e) => {
51+
if let Ok(f) = self.input.downcast::<PyFrozenSet>() {
52+
Ok(PySetAsSequence::from_frozenset(&f))
53+
} else {
54+
Err(e.into())
55+
}
56+
}
57+
}
58+
}
59+
4760
fn dict_access(&self) -> Result<PyMappingAccess<'py>> {
4861
PyMappingAccess::new(self.input.downcast()?)
4962
}
@@ -122,10 +135,9 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
122135
self.deserialize_bytes(visitor)
123136
} else if obj.is_instance_of::<PyFloat>() {
124137
self.deserialize_f64(visitor)
125-
} else if obj.is_instance_of::<PyFrozenSet>()
126-
|| obj.is_instance_of::<PySet>()
127-
|| obj.downcast::<PySequence>().is_ok()
128-
{
138+
} else if obj.is_instance_of::<PyFrozenSet>() || obj.is_instance_of::<PySet>() {
139+
self.deserialize_seq(visitor)
140+
} else if obj.downcast::<PySequence>().is_ok() {
129141
self.deserialize_tuple(obj.len()?, visitor)
130142
} else if obj.downcast::<PyMapping>().is_ok() {
131143
self.deserialize_map(visitor)
@@ -238,7 +250,18 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
238250
where
239251
V: de::Visitor<'de>,
240252
{
241-
visitor.visit_seq(self.sequence_access(None)?)
253+
match self.sequence_access(None) {
254+
Ok(seq) => visitor.visit_seq(seq),
255+
Err(e) => {
256+
// we allow sets to be deserialized as sequences, so try that
257+
if matches!(*e.inner, ErrorImpl::UnexpectedType(_)) {
258+
if let Ok(set) = self.set_access() {
259+
return visitor.visit_seq(set);
260+
}
261+
}
262+
Err(e)
263+
}
264+
}
242265
}
243266

244267
fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
@@ -357,6 +380,40 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> {
357380
}
358381
}
359382

383+
struct PySetAsSequence<'py> {
384+
iter: Bound<'py, PyIterator>,
385+
}
386+
387+
impl<'py> PySetAsSequence<'py> {
388+
fn from_set(set: &Bound<'py, PySet>) -> Self {
389+
Self {
390+
iter: PyIterator::from_bound_object(&set).expect("set is always iterable"),
391+
}
392+
}
393+
394+
fn from_frozenset(set: &Bound<'py, PyFrozenSet>) -> Self {
395+
Self {
396+
iter: PyIterator::from_bound_object(&set).expect("frozenset is always iterable"),
397+
}
398+
}
399+
}
400+
401+
impl<'de> de::SeqAccess<'de> for PySetAsSequence<'_> {
402+
type Error = PythonizeError;
403+
404+
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
405+
where
406+
T: de::DeserializeSeed<'de>,
407+
{
408+
match self.iter.next() {
409+
Some(item) => seed
410+
.deserialize(&mut Depythonizer::from_object(&item?))
411+
.map(Some),
412+
None => Ok(None),
413+
}
414+
}
415+
}
416+
360417
struct PyMappingAccess<'py> {
361418
keys: Bound<'py, PySequence>,
362419
values: Bound<'py, PySequence>,
@@ -606,6 +663,22 @@ mod test {
606663
test_de(code, &expected, &expected_json);
607664
}
608665

666+
#[test]
667+
fn test_vec_from_pyset() {
668+
let expected = vec!["foo".to_string()];
669+
let expected_json = json!(["foo"]);
670+
let code = "{'foo'}";
671+
test_de(code, &expected, &expected_json);
672+
}
673+
674+
#[test]
675+
fn test_vec_from_pyfrozenset() {
676+
let expected = vec!["foo".to_string()];
677+
let expected_json = json!(["foo"]);
678+
let code = "frozenset({'foo'})";
679+
test_de(code, &expected, &expected_json);
680+
}
681+
609682
#[test]
610683
fn test_vec() {
611684
let expected = vec![3, 2, 1];

src/error.rs

+9
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ impl PythonizeError {
3232
}
3333
}
3434

35+
pub(crate) fn unexpected_type<T>(t: T) -> Self
36+
where
37+
T: ToString,
38+
{
39+
Self {
40+
inner: Box::new(ErrorImpl::UnexpectedType(t.to_string())),
41+
}
42+
}
43+
3544
pub(crate) fn dict_key_not_string() -> Self {
3645
Self {
3746
inner: Box::new(ErrorImpl::DictKeyNotString),

0 commit comments

Comments
 (0)