Skip to content

Commit 9692706

Browse files
committed
[Bug fix] Add tests on StructArray shape (#17)
1 parent b7689ed commit 9692706

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tests/test_structarray.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ class StructArrayTestCase(unittest.TestCase):
88
def setUp(self):
99
# Prepare some example data for testing
1010
self.struct1 = {'a': 1, 'b': 2}
11-
self.struct2 = {'a': 3, 'c': 4}
11+
self.struct2 = {'a': 2, 'c': 4}
12+
self.struct3 = {'a': 3, 'c': 5}
13+
self.struct4 = {'a': 4, 'b': 1}
1214
self.array1d = [self.struct1, self.struct2]
13-
self.array2d = np.array([[self.struct1, self.struct2]])
15+
self.array2d = np.array([[self.struct1, self.struct2], [self.struct3, self.struct4]])
1416

1517
def test_initialization_with_structs(self):
1618
sa = StructArray(self.array1d)
@@ -35,8 +37,8 @@ def test_as_matlab_object(self):
3537
sa = StructArray(self.array2d)
3638
objdict = sa._as_matlab_object()
3739
self.assertEqual(objdict['type__'], 'structarray')
38-
self.assertEqual(objdict['size__'].tolist(), list(self.array2d.shape))
39-
self.assertEqual(np.asarray(objdict['data__']).shape, self.array2d.shape)
40+
self.assertEqual(objdict['size__'].reshape(-1).tolist(), list(self.array2d.shape))
41+
self.assertEqual(len(objdict['data__']), self.array2d.size)
4042

4143
def test_from_matlab_object(self):
4244
sa = StructArray(self.array2d)
@@ -50,6 +52,15 @@ def test_with_struct(self):
5052
objdict = sa._as_matlab_object()
5153
reconstructed_sa = StructArray._from_matlab_object(objdict)
5254

55+
def test_flat_shape(self):
56+
sa = StructArray(self.array2d)
57+
objdict = sa._as_matlab_object()
58+
self.assertEqual(len(objdict['data__']), 4)
59+
self.assertEqual(objdict['data__'][0]['a'], self.struct1['a'])
60+
self.assertEqual(objdict['data__'][1]['a'], self.struct3['a'])
61+
self.assertEqual(objdict['data__'][2]['a'], self.struct2['a'])
62+
self.assertEqual(objdict['data__'][3]['a'], self.struct4['a'])
63+
5364
def test_repr(self):
5465
sa = StructArray(self.array1d)
5566
repr_str = repr(sa)

0 commit comments

Comments
 (0)