Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 127 additions & 16 deletions icechunk-python/python/icechunk/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from icechunk import IcechunkStore, Session
from icechunk.session import ForkSession
from icechunk.vendor.xarray import _choose_default_mode
from xarray import DataArray, Dataset
from xarray import DataArray, Dataset, DataTree
from xarray.backends.common import ArrayWriter
from xarray.backends.zarr import ZarrStore

Expand All @@ -32,7 +32,10 @@
)

if Version(xr.__version__) > Version("2025.09.0"):
from xarray.backends.writers import _validate_dataset_names, dump_to_store # type: ignore[import-not-found]
from xarray.backends.writers import ( # type: ignore[import-not-found]
_validate_dataset_names,
dump_to_store,
)
else:
from xarray.backends.api import _validate_dataset_names, dump_to_store

Expand Down Expand Up @@ -185,6 +188,48 @@ def write_lazy(
return session_merge_reduction(stored_arrays, split_every=split_every)


def write_ds(
    ds,
    store,
    safe_chunks,
    group,
    mode,
    append_dim,
    region,
    encoding,
    chunkmanager_store_kwargs,
):
    """Write a single Dataset to ``group`` of an Icechunk ``store``.

    Opens the target group, writes metadata first, then eagerly stores
    in-memory variables, and finally hands chunked (e.g. dask-backed)
    variables to the chunk manager.

    Returns whatever ``_XarrayDatasetWriter.write_lazy`` returns — possibly
    a forked session, or ``None`` when there was nothing lazy to write.
    """
    ds_writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks)
    ds_writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)

    # Metadata goes first so the arrays have somewhere to land.
    ds_writer.write_metadata(encoding)
    # In-memory (non-lazy) variables are stored immediately...
    ds_writer.write_eager()
    # ...then chunked arrays are written via the chunk manager.
    return ds_writer.write_lazy(chunkmanager_store_kwargs=chunkmanager_store_kwargs)


# overload because several kwargs are currently forbidden for DataTree, and ``write_inherited_coords`` only applies to DataTree
# DataTree overload: intentionally omits ``group``, ``append_dim`` and
# ``region`` (not yet supported for trees — see the NotImplementedError
# checks in the implementation) and accepts ``write_inherited_coords``.
@overload
def to_icechunk(
    obj: DataTree,
    session: Session,
    *,
    mode: ZarrWriteModes | None = None,
    safe_chunks: bool = True,
    encoding: Mapping[Any, Any] | None = None,
    chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
    write_inherited_coords: bool = False,
    split_every: int | None = None,
) -> None: ...


@overload
def to_icechunk(
obj: DataArray | Dataset,
session: Session,
Expand All @@ -197,14 +242,32 @@ def to_icechunk(
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
) -> None: ...


def to_icechunk(
obj: DataArray | Dataset | DataTree,
session: Session,
*,
group: str | None = None,
mode: ZarrWriteModes | None = None,
safe_chunks: bool = True,
append_dim: Hashable | None = None,
region: Region = None,
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
write_inherited_coords: bool = False,
split_every: int | None = None,
) -> None:
"""
Write an Xarray object to a group of an Icechunk store.

Parameters
----------
obj: DataArray or Dataset
Xarray object to write
obj: DataArray, Dataset, or DataTree
Xarray object to write.

Note: When passing a DataTree, the ``append_dim``, ``region``, and ``group`` parameters are not yet supported.
session : icechunk.Session
Writable Icechunk Session
mode : {"w", "w-", "a", "a-", "r+", None}, optional
Expand Down Expand Up @@ -262,6 +325,11 @@ def to_icechunk(
Additional keyword arguments passed on to the `ChunkManager.store` method used to store
chunked arrays. For example for a dask array additional kwargs will be passed eventually to
`dask.array.store()`. Experimental API that should not be relied upon.
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
split_every: int, optional
Number of tasks to merge at every level of the tree reduction.

Expand All @@ -280,11 +348,25 @@ def to_icechunk(
``append_dim`` at the same time. To create empty arrays to fill
in with ``region``, use the `_XarrayDatasetWriter` directly.
"""

as_dataset = _make_dataset(obj)
# Validate parameters for DataTree
if isinstance(obj, DataTree):
if group is not None:
raise NotImplementedError(
"specifying a root group for the tree has not been implemented"
)
if append_dim is not None:
raise NotImplementedError(
"The 'append_dim' parameter is not yet supported when writing DataTree objects."
)
if region is not None:
raise NotImplementedError(
"The 'region' parameter is not yet supported when writing DataTree objects."
)

# This ugliness is needed so that we allow users to call `to_icechunk` with a dirty Session
# for _serial_ writes

# TODO DataTree does not implement `__dask_graph__`, unlike `Dataset`, so will this ever trigger?
is_dask = is_dask_collection(obj)
fork: Session | ForkSession
if is_dask:
Expand All @@ -296,18 +378,47 @@ def to_icechunk(
else:
fork = session

writer = _XarrayDatasetWriter(as_dataset, store=fork.store, safe_chunks=safe_chunks)
if isinstance(obj, DataTree):
dt = obj

writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)
if encoding is None:
encoding = {}
if set(encoding) - set(dt.groups):
raise ValueError(
f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
)

for rel_path, node in dt.subtree_with_keys:
at_root = node is dt
dataset = node.to_dataset(inherit=write_inherited_coords or at_root)

# TODO what do I do with all these maybe_fork_sessions here?
maybe_fork_session = write_ds(
ds=dataset,
store=fork.store,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for the future: this should be safe since each iteration of the loop writes to a different group, so there are no conflicts.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this as a comment in 7b0fec5

safe_chunks=safe_chunks,
group=dt[rel_path].path,
mode=mode,
append_dim=append_dim,
region=region,
encoding=encoding,
chunkmanager_store_kwargs=chunkmanager_store_kwargs,
)

else:
as_dataset = _make_dataset(obj)
maybe_fork_session = write_ds(
ds=as_dataset,
store=fork.store,
safe_chunks=safe_chunks,
group=group,
mode=mode,
append_dim=append_dim,
region=region,
encoding=encoding,
chunkmanager_store_kwargs=chunkmanager_store_kwargs,
)

# write metadata
writer.write_metadata(encoding)
# write in-memory arrays
writer.write_eager()
# eagerly write dask arrays
maybe_fork_session = writer.write_lazy(
chunkmanager_store_kwargs=chunkmanager_store_kwargs
)
if is_dask:
if maybe_fork_session is None:
raise RuntimeError(
Expand Down
63 changes: 62 additions & 1 deletion icechunk-python/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ def create_test_data(
return obj


def create_test_datatree() -> xr.DataTree:
    """Build a small three-node DataTree exercising coordinate inheritance.

    The root defines ``x`` (a dimension coordinate children inherit and cannot
    override) and ``w`` (a non-dimension coordinate children may override).
    ``/a`` overrides ``w`` and adds its own ``z``; ``/b`` defines a ``z`` of a
    different length than ``/a``'s (multi-resolution style).
    """
    root = xr.Dataset(
        data_vars={"bar": ("x", ["hello", "world"])},
        coords={
            # inherited dimension coordinate that can't be overridden
            "x": ("x", [1, 2]),
            # inherited non-dimension coordinate to override
            "w": ("x", [0.1, 0.2]),
        },
    )
    node_a = xr.Dataset(
        data_vars={"foo": ("x", ["alpha", "beta"])},
        coords={
            # override inherited non-dimension coordinate
            "w": ("x", [10, 20]),
            # non-inherited dimension coordinate
            "z": ("z", ["alpha", "beta"]),
        },
    )
    node_b = xr.Dataset(
        data_vars={"foo": ("x", ["gamma", "delta"])},
        coords={
            # a ``z`` with a different length than /a's (i.e. multi-resolution)
            "z": ("z", ["alpha", "beta", "gamma"]),
        },
    )
    return xr.DataTree.from_dict({"/": root, "/a": node_a, "/b": node_b})


@contextlib.contextmanager
def roundtrip(
data: xr.Dataset, *, commit: bool = False
Expand All @@ -62,12 +104,31 @@ def roundtrip(
yield ds


def test_xarray_to_icechunk() -> None:
def test_xarray_dataset_to_icechunk() -> None:
    """Round-trip a plain Dataset through icechunk and compare."""
    expected = create_test_data()
    with roundtrip(expected) as roundtripped:
        assert_identical(roundtripped, expected)


@contextlib.contextmanager
def roundtrip_datatree(
    dt: xr.DataTree, *, commit: bool = False
) -> Generator[xr.DataTree, None, None]:
    """Write ``dt`` to a temporary icechunk repo and yield it re-opened.

    NOTE(review): the ``commit`` keyword is currently ignored — the session
    is always committed before re-opening. Kept for signature symmetry with
    ``roundtrip``; confirm whether it should gate the commit.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        repo = Repository.create(local_filesystem_storage(tmpdir))
        session = repo.writable_session("main")
        to_icechunk(dt, session=session, mode="w")
        session.commit("write")
        # Use a distinct name so the re-opened tree does not shadow the
        # ``dt`` argument (the original code rebound ``dt`` here).
        with xr.open_datatree(
            session.store, consolidated=False, engine="zarr"
        ) as actual:
            yield actual


def test_xarray_datatree_to_icechunk() -> None:
    """Round-trip a DataTree through icechunk and compare."""
    expected = create_test_datatree()
    with roundtrip_datatree(expected) as roundtripped:
        assert_identical(roundtripped, expected)


def test_repeated_to_icechunk_serial() -> None:
ds = create_test_data()
repo = Repository.create(in_memory_storage())
Expand Down
Loading