Skip to content

Commit

Permalink
Set object codec for object arrays (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Sep 30, 2024
1 parent d24f83b commit 87db8ba
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/jax-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
- name: Run tests
run: |
# exclude tests that rely on structured or object dtypes since JAX doesn't support these
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby"
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not object_dtype"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
36 changes: 26 additions & 10 deletions cubed/storage/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,7 @@
from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store


def open_backend_array(
store: T_Store,
mode: str,
*,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
path: Optional[str] = None,
**kwargs,
):
def backend_storage_name():
# get storage name from top-level config
# e.g. set globally with CUBED_STORAGE_NAME=tensorstore
storage_name = config.get("storage_name", None)
Expand All @@ -26,10 +17,35 @@ def open_backend_array(
else:
storage_name = "zarr-python"

return storage_name


def open_backend_array(
store: T_Store,
mode: str,
*,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
path: Optional[str] = None,
**kwargs,
):
storage_name = backend_storage_name()

if storage_name == "zarr-python":
from cubed.storage.backends.zarr_python import open_zarr_array

open_func = open_zarr_array

# set object codec if needed
import numpy as np

if np.dtype(dtype).hasobject and "object_codec" not in kwargs:
import numcodecs

object_codec = numcodecs.Pickle()
kwargs["object_codec"] = object_codec

elif storage_name == "zarr-python-v3":
from cubed.storage.backends.zarr_python_v3 import open_zarr_v3_array

Expand Down
12 changes: 12 additions & 0 deletions cubed/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import pytest
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp
from cubed.storage.backend import backend_storage_name


# This is less strict than the spec, but is supported by implementations like NumPy
def test_prod_sum_bool():
    """Reduce a two-element boolean array of ones with ``prod`` and ``sum``.

    The computed product is compared against 1 and the computed sum against 2,
    each expressed as an int64 array.
    """
    flags = xp.ones((2,), dtype=xp.bool)

    product = xp.prod(flags).compute()
    assert_array_equal(product, xp.asarray([1], dtype=xp.int64))

    total = xp.sum(flags).compute()
    assert_array_equal(total, xp.asarray([2], dtype=xp.int64))


@pytest.mark.skipif(
    backend_storage_name() != "zarr-python",
    reason="object dtype only works on zarr-python",
)
def test_object_dtype():
    """Write a small object-dtype array via ``cubed.to_zarr``.

    There is no explicit assertion: the test passes if the write completes
    without raising. It is skipped unless the configured storage backend is
    "zarr-python".
    """
    strings = xp.asarray(["a", "b"], dtype=object, chunks=2)
    cubed.to_zarr(strings, store=None)

0 comments on commit 87db8ba

Please sign in to comment.