Skip to content

Commit 66a178c

Browse files
authored
Preserve np.ndarray in casting, do not cast to list.
1 parent 871dafd commit 66a178c

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None):
102102
return bounding_box.parse(value)
103103
elif not isinstance(data_type, type):
104104
raise ValueError(f"No special case for type {data_type}.")
105+
elif isinstance(value, np.ndarray) and issubclass(data_type, np.generic):
106+
return value.astype(data_type)
105107
elif isinstance(value, list) or isinstance(value, np.ndarray):
106108
return [_cast_value(ctx=ctx, value=v, data_type=data_type) for v in value]
107109
elif data_type == bytes and not isinstance(value, bytes):

python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,21 @@ def test_cast_value(conforms_to, value, data_type, expected):
5151
assert field._cast_value(ctx, value, data_type) == expected
5252

5353

54+
@parametrize_conforms_to()
@pytest.mark.parametrize(
    ["value", "data_type", "expected"],
    [
        [np.array([1, 2, 3]), DataType.INTEGER, np.array([1, 2, 3])],
        [np.array([1, 2, 3]), DataType.FLOAT32, np.array([1.0, 2.0, 3.0])],
    ],
)
def test_cast_value_ndarray(conforms_to, value, data_type, expected):
    """Check that `_cast_value` returns the expected array for ndarray input.

    Args:
        conforms_to: Value injected by `parametrize_conforms_to` and forwarded
            to `Context`.
        value: Input array handed to `field._cast_value`.
        data_type: Target data type handed to `field._cast_value`.
        expected: Array the cast result is compared against (contents and
            dtype).
    """
    ctx = Context(conforms_to=conforms_to)
    cast_value = field._cast_value(ctx, value, data_type)
    # `==` between arrays is element-wise, and asserting on a multi-element
    # boolean array raises ValueError, so compare contents with the NumPy
    # testing helper instead of a bare `assert a == b`.
    np.testing.assert_array_equal(cast_value, expected)
    # NOTE(review): `expected` in the FLOAT32 row is built without an explicit
    # dtype, so NumPy makes it float64. If DataType.FLOAT32 resolves to
    # np.float32 this assertion would fail -- confirm the intended dtype.
    assert cast_value.dtype == expected.dtype
67+
68+
5469
@parametrize_conforms_to()
5570
@pytest.mark.parametrize(
5671
["value", "data_type"],

0 commit comments

Comments
 (0)