@@ -8,9 +8,11 @@ class StructArrayTestCase(unittest.TestCase):
88 def setUp (self ):
99 # Prepare some example data for testing
1010 self .struct1 = {'a' : 1 , 'b' : 2 }
11- self .struct2 = {'a' : 3 , 'c' : 4 }
11+ self .struct2 = {'a' : 2 , 'c' : 4 }
12+ self .struct3 = {'a' : 3 , 'c' : 5 }
13+ self .struct4 = {'a' : 4 , 'b' : 1 }
1214 self .array1d = [self .struct1 , self .struct2 ]
13- self .array2d = np .array ([[self .struct1 , self .struct2 ]])
15+ self .array2d = np .array ([[self .struct1 , self .struct2 ], [ self . struct3 , self . struct4 ] ])
1416
1517 def test_initialization_with_structs (self ):
1618 sa = StructArray (self .array1d )
@@ -35,8 +37,8 @@ def test_as_matlab_object(self):
3537 sa = StructArray (self .array2d )
3638 objdict = sa ._as_matlab_object ()
3739 self .assertEqual (objdict ['type__' ], 'structarray' )
38- self .assertEqual (objdict ['size__' ].tolist (), list (self .array2d .shape ))
39- self .assertEqual (np . asarray (objdict ['data__' ]). shape , self .array2d .shape )
40+ self .assertEqual (objdict ['size__' ].reshape ( - 1 ). tolist (), list (self .array2d .shape ))
41+ self .assertEqual (len (objdict ['data__' ]), self .array2d .size )
4042
4143 def test_from_matlab_object (self ):
4244 sa = StructArray (self .array2d )
@@ -50,6 +52,15 @@ def test_with_struct(self):
5052 objdict = sa ._as_matlab_object ()
5153 reconstructed_sa = StructArray ._from_matlab_object (objdict )
5254
55+ def test_flat_shape (self ):
56+ sa = StructArray (self .array2d )
57+ objdict = sa ._as_matlab_object ()
58+ self .assertEqual (len (objdict ['data__' ]), 4 )
59+ self .assertEqual (objdict ['data__' ][0 ]['a' ], self .struct1 ['a' ])
60+ self .assertEqual (objdict ['data__' ][1 ]['a' ], self .struct3 ['a' ])
61+ self .assertEqual (objdict ['data__' ][2 ]['a' ], self .struct2 ['a' ])
62+ self .assertEqual (objdict ['data__' ][3 ]['a' ], self .struct4 ['a' ])
63+
5364 def test_repr (self ):
5465 sa = StructArray (self .array1d )
5566 repr_str = repr (sa )
0 commit comments