Skip to content

Commit 50420e1

Browse files
committed
[fix] use append_fields in dict_to_structured and also fixed append_fields for arbitrary shaped data
1 parent 8129d6e commit 50420e1

1 file changed

Lines changed: 43 additions & 27 deletions

File tree

src/oqd_dataschema/utils.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,34 +26,40 @@
2626
########################################################################################
2727

2828

29-
def _unstructured_to_structured_helper(data, dtype):
    """Assemble a structured record array from a queue of per-leaf arrays.

    The fields of ``dtype`` are visited in declaration order. Every leaf
    field consumes the next array from the front of ``data``; every nested
    structured field is built recursively from the same queue.

    Args:
        data: Mutable list of arrays, one per leaf field of ``dtype`` in
            flattened field order. Consumed entries are removed in place.
        dtype: Structured ``np.dtype`` describing the result.

    Returns:
        An ``np.recarray`` with one field per entry of ``dtype.names`` and
        the same shape as the consumed arrays.

    Raises:
        ValueError: If ``dtype`` has no fields, or if the arrays built for
            the individual fields do not all share one shape.
    """
    if not dtype.names:
        raise ValueError("Structured dtype must contain at least one field.")

    columns = []
    for name in dtype.names:
        # ``dtype.fields[name]`` is (dtype, offset) or (dtype, offset, title);
        # indexing instead of unpacking keeps titled fields working.
        field_dtype = dtype.fields[name][0]

        if field_dtype.fields is not None:
            column = _unstructured_to_structured_helper(data, field_dtype)
        else:
            # Cast to the dtype instance (not its class) so byte order,
            # string length and datetime units of the target are honoured.
            column = data.pop(0).astype(field_dtype)

        if columns and columns[0][1].shape != column.shape:
            raise ValueError(
                f"Incompatible shape, expected {columns[0][1].shape} but got {column.shape}."
            )
        columns.append((name, column))

    # Allocate once and assign field by field. This handles nested first
    # fields (which cannot be produced by a plain ``astype``) and arbitrary
    # array shapes without any flatten/reshape round trip.
    new_data = np.empty(
        columns[0][1].shape,
        dtype=np.dtype([(name, column.dtype) for name, column in columns]),
    )
    for name, column in columns:
        new_data[name] = column

    return new_data.view(np.recarray)
5256

5357

5458
def unstructured_to_structured(data, dtype):
59+
data = list(np.moveaxis(data, -1, 0))
60+
5561
leaves = len(rfn.flatten_descr(dtype))
56-
if data.shape[-1] != leaves:
62+
if len(data) != leaves:
5763
raise ValueError(
5864
f"Incompatible shape, last dimension of data ({data.shape[-1]}) must match number of leaves in structured dtype ({leaves})."
5965
)
@@ -77,28 +83,38 @@ def _dtype_from_dict(data):
7783
return np.dtype(np_dtype)
7884

7985

80-
def _dict_to_structured_helper(data, dtype):
    """Assemble a structured record array from a (possibly nested) mapping.

    The fields of ``dtype`` are visited in declaration order. A leaf field
    reads ``data[name]``; a nested structured field recurses into the
    sub-mapping ``data[name]``. The input mapping is never modified.

    Args:
        data: Mapping from field name to an array (leaf field) or to another
            mapping of the same form (nested field).
        dtype: Structured ``np.dtype`` whose field names are keys of ``data``.

    Returns:
        An ``np.recarray`` with one field per entry of ``dtype.names`` and
        the same shape as the leaf arrays.

    Raises:
        ValueError: If ``dtype`` has no fields, or if the arrays built for
            the individual fields do not all share one shape.
        KeyError: If a field name of ``dtype`` is missing from ``data``.
    """
    if not dtype.names:
        raise ValueError("Structured dtype must contain at least one field.")

    columns = []
    for name in dtype.names:
        # ``dtype.fields[name]`` is (dtype, offset) or (dtype, offset, title);
        # indexing instead of unpacking keeps titled fields working.
        field_dtype = dtype.fields[name][0]

        if field_dtype.fields is not None:
            column = _dict_to_structured_helper(data[name], field_dtype)
        else:
            # Look the leaf up by key: ``data`` is a mapping, so popping
            # position 0 would raise KeyError and mutate the caller's dict.
            column = np.asarray(data[name]).astype(field_dtype)

        if columns and columns[0][1].shape != column.shape:
            raise ValueError(
                f"Incompatible shape, expected {columns[0][1].shape} but got {column.shape}."
            )
        columns.append((name, column))

    # Allocate once and assign field by field. This handles nested first
    # fields (which cannot be produced by a plain ``astype``) and arbitrary
    # array shapes without any flatten/reshape round trip.
    new_data = np.empty(
        columns[0][1].shape,
        dtype=np.dtype([(name, column.dtype) for name, column in columns]),
    )
    for name, column in columns:
        new_data[name] = column

    return new_data.view(np.recarray)
88112

89113

90114
def dict_to_structured(data):
    """Convert a (possibly nested) dict of arrays into a structured array.

    The structured dtype is derived from ``data`` by ``_dtype_from_dict``;
    the array itself is assembled by ``_dict_to_structured_helper``.
    """
    return _dict_to_structured_helper(data, dtype=_dtype_from_dict(data))
104120

0 commit comments

Comments
 (0)