diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py index 056899b6..d5a70933 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py @@ -102,6 +102,8 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None): return bounding_box.parse(value) elif not isinstance(data_type, type): raise ValueError(f"No special case for type {data_type}.") + elif isinstance(value, np.ndarray) and issubclass(data_type, np.generic): + return value.astype(data_type) elif isinstance(value, list) or isinstance(value, np.ndarray): return [_cast_value(ctx=ctx, value=v, data_type=data_type) for v in value] elif data_type == bytes and not isinstance(value, bytes): diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py index 4299193b..f71d8c3b 100644 --- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py +++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py @@ -51,6 +51,21 @@ def test_cast_value(conforms_to, value, data_type, expected): assert field._cast_value(ctx, value, data_type) == expected +@parametrize_conforms_to() +@pytest.mark.parametrize( + ["value", "data_type", "expected"], + [ + [np.array([1, 2, 3]), DataType.INTEGER, np.array([1, 2, 3])], + [np.array([1, 2, 3]), DataType.FLOAT32, np.array([1.0, 2.0, 3.0])], + ], +) +def test_cast_value_ndarray(): + ctx = Context(conforms_to=conforms_to) + cast_value = field._cast_value(ctx, value, data_type) + assert cast_value == expected + assert cast_value.dtype == expected.dtype + + @parametrize_conforms_to() @pytest.mark.parametrize( ["value", "data_type"],