diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py
index 66f1b3b21..540d3cf22 100644
--- a/icechunk-python/python/icechunk/xarray.py
+++ b/icechunk-python/python/icechunk/xarray.py
@@ -9,9 +9,10 @@
 import xarray as xr
 import zarr
 from icechunk import IcechunkStore, Session
+from icechunk.distributed import merge_sessions
 from icechunk.session import ForkSession
 from icechunk.vendor.xarray import _choose_default_mode
-from xarray import DataArray, Dataset
+from xarray import DataArray, Dataset, DataTree
 from xarray.backends.common import ArrayWriter
 from xarray.backends.zarr import ZarrStore
 
@@ -23,8 +24,13 @@
 
 try:
     has_dask = importlib.util.find_spec("dask") is not None
+    if has_dask:
+        from dask.highlevelgraph import HighLevelGraph
+    else:
+        HighLevelGraph = None
 except ImportError:
     has_dask = False
+    HighLevelGraph = None
 
 if Version(xr.__version__) < Version("2024.10.0"):
     raise ValueError(
@@ -44,11 +50,39 @@ def is_dask_collection(x: Any) -> bool:
     if has_dask:
         import dask
 
-        return dask.base.is_dask_collection(x)
+        if isinstance(x, DataTree):
+            return bool(datatree_dask_graph(x))
+        else:
+            return dask.base.is_dask_collection(x)
     else:
         return False
 
 
+def datatree_dask_graph(dt: DataTree) -> "HighLevelGraph | None":  # type: ignore[name-defined]
+    # copied from `Dataset.__dask_graph__()`.
+    # Should ideally be upstreamed into xarray as part of making DataTree a true dask collection - see https://github.com/pydata/xarray/issues/9355.
+
+    all_variables = {
+        f"{path}/{var_name}" if path != "." else var_name: variable
+        for path, node in dt.subtree_with_keys
+        for var_name, variable in node.variables.items()
+    }
+
+    graphs = {k: v.__dask_graph__() for k, v in all_variables.items()}
+    graphs = {k: v for k, v in graphs.items() if v is not None}
+    if not graphs:
+        return None
+    else:
+        try:
+            from dask.highlevelgraph import HighLevelGraph
+
+            return HighLevelGraph.merge(*graphs.values())
+        except ImportError:
+            from dask import sharedict
+
+            return sharedict.merge(*graphs.values())
+
+
 class LazyArrayWriter(ArrayWriter):
     def __init__(self) -> None:
         super().__init__()  # type: ignore[no-untyped-call]
@@ -188,6 +222,48 @@ def write_lazy(
         return session_merge_reduction(stored_arrays, split_every=split_every)
 
 
+def write_ds(
+    ds,
+    store,
+    safe_chunks,
+    group,
+    mode,
+    append_dim,
+    region,
+    encoding,
+    chunkmanager_store_kwargs,
+) -> ForkSession | None:
+    writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks)
+    writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)
+
+    # write metadata
+    writer.write_metadata(encoding)
+    # write in-memory arrays
+    writer.write_eager()
+    # eagerly write dask arrays
+    maybe_fork_session = writer.write_lazy(
+        chunkmanager_store_kwargs=chunkmanager_store_kwargs
+    )
+
+    return maybe_fork_session
+
+
+# overload because several kwargs are currently forbidden for DataTree, and ``write_inherited_coords`` only applies to DataTree
+@overload
+def to_icechunk(
+    obj: DataTree,
+    session: Session,
+    *,
+    mode: ZarrWriteModes | None = None,
+    safe_chunks: bool = True,
+    encoding: Mapping[Any, Any] | None = None,
+    chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
+    write_inherited_coords: bool = False,
+    split_every: int | None = None,
+) -> None: ...
+
+
+@overload
 def to_icechunk(
     obj: DataArray | Dataset,
     session: Session,
     *,
@@ -200,14 +276,32 @@
     encoding: Mapping[Any, Any] | None = None,
     chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
     split_every: int | None = None,
+) -> None: ...
+
+
+def to_icechunk(
+    obj: DataArray | Dataset | DataTree,
+    session: Session,
+    *,
+    group: str | None = None,
+    mode: ZarrWriteModes | None = None,
+    safe_chunks: bool = True,
+    append_dim: Hashable | None = None,
+    region: Region = None,
+    encoding: Mapping[Any, Any] | None = None,
+    chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
+    write_inherited_coords: bool = False,
+    split_every: int | None = None,
 ) -> None:
     """
     Write an Xarray object to a group of an Icechunk store.
 
     Parameters
     ----------
-    obj: DataArray or Dataset
-        Xarray object to write
+    obj: DataArray, Dataset, or DataTree
+        Xarray object to write.
+
+        Note: When passing a DataTree, the ``append_dim``, ``region``, and ``group`` parameters are not yet supported.
     session : icechunk.Session
         Writable Icechunk Session
     mode : {"w", "w-", "a", "a-", r+", None}, optional
@@ -265,6 +359,11 @@
         Additional keyword arguments passed on to the `ChunkManager.store` method used to store
        chunked arrays. For example for a dask array additional kwargs will be passed eventually to
         `dask.array.store()`. Experimental API that should not be relied upon.
+    write_inherited_coords : bool, default: False
+        If true, replicate inherited coordinates on all descendant nodes.
+        Otherwise, only write coordinates at the level at which they are
+        originally defined. This saves disk space, but requires opening the
+        full tree to load inherited coordinates.
     split_every: int, optional
         Number of tasks to merge at every level of the tree reduction.
 
@@ -283,8 +382,20 @@
     ``append_dim`` at the same time. To create empty arrays to fill in with
     ``region``, use the `_XarrayDatasetWriter` directly.
     """
-
-    as_dataset = _make_dataset(obj)
+    # Validate parameters for DataTree
+    if isinstance(obj, DataTree):
+        if group is not None:
+            raise NotImplementedError(
+                "specifying a root group for the tree has not been implemented"
+            )
+        if append_dim is not None:
+            raise NotImplementedError(
+                "The 'append_dim' parameter is not yet supported when writing DataTree objects."
+            )
+        if region is not None:
+            raise NotImplementedError(
+                "The 'region' parameter is not yet supported when writing DataTree objects."
+            )
 
     # This ugliness is needed so that we allow users to call `to_icechunk` with a dirty Session
     # for _serial_ writes
@@ -299,18 +410,54 @@
     else:
         fork = session
 
-    writer = _XarrayDatasetWriter(as_dataset, store=fork.store, safe_chunks=safe_chunks)
+    if isinstance(obj, DataTree):
+        dt = obj
 
-    writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)
+        if encoding is None:
+            encoding = {}
+        if set(encoding) - set(dt.groups):
+            raise ValueError(
+                f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
+            )
+
+        maybe_forked_sessions: list[ForkSession | None] = []
+        for rel_path, node in dt.subtree_with_keys:
+            at_root = node is dt
+            dataset = node.to_dataset(inherit=write_inherited_coords or at_root)
+
+            maybe_fork_session = write_ds(
+                ds=dataset,
+                store=fork.store,
+                safe_chunks=safe_chunks,
+                group=dt[rel_path].path,
+                mode=mode,
+                append_dim=append_dim,
+                region=region,
+                encoding=encoding,
+                chunkmanager_store_kwargs=chunkmanager_store_kwargs,
+            )
+            maybe_forked_sessions.append(maybe_fork_session)
+
+        if any(maybe_forked_sessions) and is_dask:
+            # Note: This should be safe since each iteration of the loop writes to a different group, so there are no conflicts.
+            maybe_fork_session = merge_sessions(maybe_forked_sessions)
+        else:
+            maybe_fork_session = None
+
+    else:
+        as_dataset = _make_dataset(obj)
+        maybe_fork_session = write_ds(
+            ds=as_dataset,
+            store=fork.store,
+            safe_chunks=safe_chunks,
+            group=group,
+            mode=mode,
+            append_dim=append_dim,
+            region=region,
+            encoding=encoding,
+            chunkmanager_store_kwargs=chunkmanager_store_kwargs,
+        )
 
-    # write metadata
-    writer.write_metadata(encoding)
-    # write in-memory arrays
-    writer.write_eager()
-    # eagerly write dask arrays
-    maybe_fork_session = writer.write_lazy(
-        chunkmanager_store_kwargs=chunkmanager_store_kwargs
-    )
     if is_dask:
         if maybe_fork_session is None:
             raise RuntimeError(
diff --git a/icechunk-python/tests/test_xarray.py b/icechunk-python/tests/test_xarray.py
index 54f236eb2..b0cf7bb41 100644
--- a/icechunk-python/tests/test_xarray.py
+++ b/icechunk-python/tests/test_xarray.py
@@ -49,6 +49,48 @@ def create_test_data(
     return obj
 
 
+def create_test_datatree() -> xr.DataTree:
+    return xr.DataTree.from_dict(
+        {
+            "/": xr.Dataset(
+                data_vars={
+                    "bar": ("x", ["hello", "world"]),
+                },
+                coords={
+                    "x": (
+                        "x",
+                        [1, 2],
+                    ),  # inherited dimension coordinate that can't be overridden
+                    "w": (
+                        "x",
+                        [0.1, 0.2],
+                    ),  # inherited non-dimension coordinate to override
+                },
+            ),
+            "/a": xr.Dataset(
+                data_vars={
+                    "foo": ("x", ["alpha", "beta"]),
+                },
+                coords={
+                    "w": ("x", [10, 20]),  # override inherited non-dimension coordinate
+                    "z": ("z", ["alpha", "beta"]),  # non-inherited dimension coordinate
+                },
+            ),
+            "/b": xr.Dataset(
+                data_vars={
+                    "foo": ("x", ["gamma", "delta"]),
+                },
+                coords={
+                    "z": (
+                        "z",
+                        ["alpha", "beta", "gamma"],
+                    ),  # override inherited non-dimension coordinate with different length (i.e. multi-resolution)
+                },
+            ),
+        }
+    )
+
+
 @contextlib.contextmanager
 def roundtrip(
     data: xr.Dataset, *, commit: bool = False
@@ -62,12 +104,31 @@
             yield ds
 
 
-def test_xarray_to_icechunk() -> None:
+def test_xarray_dataset_to_icechunk() -> None:
     ds = create_test_data()
     with roundtrip(ds) as actual:
         assert_identical(actual, ds)
 
 
+@contextlib.contextmanager
+def roundtrip_datatree(
+    dt: xr.DataTree, *, commit: bool = False
+) -> Generator[xr.DataTree, None, None]:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        repo = Repository.create(local_filesystem_storage(tmpdir))
+        session = repo.writable_session("main")
+        to_icechunk(dt, session=session, mode="w")
+        session.commit("write")
+        with xr.open_datatree(session.store, consolidated=False, engine="zarr") as dt:
+            yield dt
+
+
+def test_xarray_datatree_to_icechunk() -> None:
+    dt = create_test_datatree()
+    with roundtrip_datatree(dt) as actual:
+        assert_identical(actual, dt)
+
+
 def test_repeated_to_icechunk_serial() -> None:
     ds = create_test_data()
     repo = Repository.create(in_memory_storage())
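
For context, a minimal usage sketch of the new DataTree path (not part of the diff): it mirrors the `roundtrip_datatree` test helper above but uses `in_memory_storage` instead of a temporary directory; the tree contents and commit message are illustrative.

```python
import xarray as xr
from icechunk import Repository, in_memory_storage
from icechunk.xarray import to_icechunk

# A small two-level tree; the root's "x" coordinate is inherited by the child node.
dt = xr.DataTree.from_dict(
    {
        "/": xr.Dataset(coords={"x": [1, 2]}),
        "/a": xr.Dataset({"foo": ("x", [10.0, 20.0])}),
    }
)

repo = Repository.create(in_memory_storage())
session = repo.writable_session("main")

# One zarr group is written per tree node; for a DataTree, the ``group``,
# ``append_dim`` and ``region`` kwargs raise NotImplementedError (see the
# validation added above).
to_icechunk(dt, session=session, mode="w")
session.commit("write datatree")

# Read the tree back from the committed store, as the tests do.
with xr.open_datatree(session.store, consolidated=False, engine="zarr") as actual:
    print(actual)
```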
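A variant under the same assumptions for dask-backed data, which is what exercises the `merge_sessions` import added above: each node is written through the forked session and, when any node produces a `ForkSession`, the per-group sessions are merged internally. This assumes `dask` is installed and continues from the imports in the previous sketch.

```python
# Chunk each child's dataset so its variables are dask-backed.
dt = xr.DataTree.from_dict(
    {
        "/": xr.Dataset(coords={"x": [1, 2, 3, 4]}),
        "/a": xr.Dataset({"foo": ("x", [1.0, 2.0, 3.0, 4.0])}).chunk({"x": 2}),
        "/b": xr.Dataset({"bar": ("x", [5.0, 6.0, 7.0, 8.0])}).chunk({"x": 2}),
    }
)

repo = Repository.create(in_memory_storage())
session = repo.writable_session("main")
# Chunked variables are still written eagerly (via ``write_lazy``); no extra
# user-side session merging is needed for this serial write.
to_icechunk(dt, session=session, mode="w")
session.commit("write chunked datatree")
```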