@@ -2,7 +2,7 @@ use pyo3::{types::*, Bound};
2
2
use serde:: de:: { self , DeserializeOwned , IntoDeserializer } ;
3
3
use serde:: Deserialize ;
4
4
5
- use crate :: error:: { PythonizeError , Result } ;
5
+ use crate :: error:: { ErrorImpl , PythonizeError , Result } ;
6
6
7
7
/// Attempt to convert a Python object to an instance of `T`
8
8
pub fn depythonize < ' a , ' py , T > ( obj : & ' a Bound < ' py , PyAny > ) -> Result < T >
@@ -44,6 +44,19 @@ impl<'a, 'py> Depythonizer<'a, 'py> {
44
44
}
45
45
}
46
46
47
+ fn set_access ( & self ) -> Result < PySetAsSequence < ' py > > {
48
+ match self . input . downcast :: < PySet > ( ) {
49
+ Ok ( set) => Ok ( PySetAsSequence :: from_set ( & set) ) ,
50
+ Err ( e) => {
51
+ if let Ok ( f) = self . input . downcast :: < PyFrozenSet > ( ) {
52
+ Ok ( PySetAsSequence :: from_frozenset ( & f) )
53
+ } else {
54
+ Err ( e. into ( ) )
55
+ }
56
+ }
57
+ }
58
+ }
59
+
47
60
fn dict_access ( & self ) -> Result < PyMappingAccess < ' py > > {
48
61
PyMappingAccess :: new ( self . input . downcast ( ) ?)
49
62
}
@@ -122,10 +135,9 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
122
135
self . deserialize_bytes ( visitor)
123
136
} else if obj. is_instance_of :: < PyFloat > ( ) {
124
137
self . deserialize_f64 ( visitor)
125
- } else if obj. is_instance_of :: < PyFrozenSet > ( )
126
- || obj. is_instance_of :: < PySet > ( )
127
- || obj. downcast :: < PySequence > ( ) . is_ok ( )
128
- {
138
+ } else if obj. is_instance_of :: < PyFrozenSet > ( ) || obj. is_instance_of :: < PySet > ( ) {
139
+ self . deserialize_seq ( visitor)
140
+ } else if obj. downcast :: < PySequence > ( ) . is_ok ( ) {
129
141
self . deserialize_tuple ( obj. len ( ) ?, visitor)
130
142
} else if obj. downcast :: < PyMapping > ( ) . is_ok ( ) {
131
143
self . deserialize_map ( visitor)
@@ -238,7 +250,18 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> {
238
250
where
239
251
V : de:: Visitor < ' de > ,
240
252
{
241
- visitor. visit_seq ( self . sequence_access ( None ) ?)
253
+ match self . sequence_access ( None ) {
254
+ Ok ( seq) => visitor. visit_seq ( seq) ,
255
+ Err ( e) => {
256
+ // we allow sets to be deserialized as sequences, so try that
257
+ if matches ! ( * e. inner, ErrorImpl :: UnexpectedType ( _) ) {
258
+ if let Ok ( set) = self . set_access ( ) {
259
+ return visitor. visit_seq ( set) ;
260
+ }
261
+ }
262
+ Err ( e)
263
+ }
264
+ }
242
265
}
243
266
244
267
fn deserialize_tuple < V > ( self , len : usize , visitor : V ) -> Result < V :: Value >
@@ -357,6 +380,40 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> {
357
380
}
358
381
}
359
382
383
+ struct PySetAsSequence < ' py > {
384
+ iter : Bound < ' py , PyIterator > ,
385
+ }
386
+
387
+ impl < ' py > PySetAsSequence < ' py > {
388
+ fn from_set ( set : & Bound < ' py , PySet > ) -> Self {
389
+ Self {
390
+ iter : PyIterator :: from_bound_object ( & set) . expect ( "set is always iterable" ) ,
391
+ }
392
+ }
393
+
394
+ fn from_frozenset ( set : & Bound < ' py , PyFrozenSet > ) -> Self {
395
+ Self {
396
+ iter : PyIterator :: from_bound_object ( & set) . expect ( "frozenset is always iterable" ) ,
397
+ }
398
+ }
399
+ }
400
+
401
+ impl < ' de > de:: SeqAccess < ' de > for PySetAsSequence < ' _ > {
402
+ type Error = PythonizeError ;
403
+
404
+ fn next_element_seed < T > ( & mut self , seed : T ) -> Result < Option < T :: Value > >
405
+ where
406
+ T : de:: DeserializeSeed < ' de > ,
407
+ {
408
+ match self . iter . next ( ) {
409
+ Some ( item) => seed
410
+ . deserialize ( & mut Depythonizer :: from_object ( & item?) )
411
+ . map ( Some ) ,
412
+ None => Ok ( None ) ,
413
+ }
414
+ }
415
+ }
416
+
360
417
struct PyMappingAccess < ' py > {
361
418
keys : Bound < ' py , PySequence > ,
362
419
values : Bound < ' py , PySequence > ,
@@ -606,6 +663,22 @@ mod test {
606
663
test_de ( code, & expected, & expected_json) ;
607
664
}
608
665
666
+ #[ test]
667
+ fn test_vec_from_pyset ( ) {
668
+ let expected = vec ! [ "foo" . to_string( ) ] ;
669
+ let expected_json = json ! ( [ "foo" ] ) ;
670
+ let code = "{'foo'}" ;
671
+ test_de ( code, & expected, & expected_json) ;
672
+ }
673
+
674
+ #[ test]
675
+ fn test_vec_from_pyfrozenset ( ) {
676
+ let expected = vec ! [ "foo" . to_string( ) ] ;
677
+ let expected_json = json ! ( [ "foo" ] ) ;
678
+ let code = "frozenset({'foo'})" ;
679
+ test_de ( code, & expected, & expected_json) ;
680
+ }
681
+
609
682
#[ test]
610
683
fn test_vec ( ) {
611
684
let expected = vec ! [ 3 , 2 , 1 ] ;
0 commit comments