Commit aa0f087

shoyer authored and Xarray-Beam authors committed
Add Dataset.from_ptransform
This is a variant of the Dataset constructor with extensive validation. Also
add documentation explaining how it works.

PiperOrigin-RevId: 814972825
1 parent e60b6ad · commit aa0f087

File tree

6 files changed: +541 -38 lines


docs/api.md

Lines changed: 2 additions & 1 deletion
@@ -92,9 +92,10 @@ guarantees.
    :toctree: _autosummary

    Dataset
-   Dataset.from_xarray
    Dataset.from_zarr
    Dataset.to_zarr
+   Dataset.from_xarray
+   Dataset.from_ptransform
    Dataset.collect_with_direct_runner
    Dataset.map_blocks
    Dataset.rechunk

docs/high-level.ipynb

Lines changed: 63 additions & 16 deletions
@@ -63,7 +63,7 @@
     "xbeam_ds"
    ],
    "outputs": [],
-   "execution_count": 2
+   "execution_count": 1
   },
   {
    "metadata": {
@@ -83,7 +83,7 @@
     "xarray_ds.chunk(chunks).to_zarr('example_data.zarr', mode='w')"
    ],
    "outputs": [],
-   "execution_count": 3
+   "execution_count": 2
   },
   {
    "metadata": {
@@ -186,7 +186,7 @@
     "xarray.open_zarr('example_climatology.zarr')"
    ],
    "outputs": [],
-   "execution_count": 6
+   "execution_count": 3
   },
   {
    "metadata": {
@@ -215,7 +215,7 @@
     "xarray.open_zarr('example_regrid.zarr')"
    ],
    "outputs": [],
-   "execution_count": 7
+   "execution_count": 4
   },
   {
    "metadata": {
@@ -245,15 +245,15 @@
     "  print(f'{type(e).__name__}: {e}')"
    ],
    "outputs": [],
-   "execution_count": 8
+   "execution_count": 5
   },
   {
    "metadata": {
     "id": "vCjZK9fmEeEq"
    },
    "cell_type": "markdown",
    "source": [
-    "You can avoid these errors by explicitly [creating a template](creating_templates):"
+    "You can avoid these errors by explicitly supplying a template, either from {py:attr}`Dataset.template \u003cxarray_beam.Dataset.template\u003e` or produced by {py:func}`~xarray_beam.make_template`:"
    ]
   },
   {
@@ -262,19 +262,70 @@
    },
    "cell_type": "code",
    "source": [
-    "ds_beam = xbeam.Dataset.from_zarr('example_data.zarr')\n",
-    "ds_beam.map_blocks(lambda ds: ds.compute(), template=ds_beam.template)"
+    "template = xbeam.make_template(xarray_ds)\n",
+    "(\n",
+    "    xbeam.Dataset.from_zarr('example_data.zarr')\n",
+    "    .map_blocks(lambda ds: ds.compute(), template=template)\n",
+    ")"
    ],
    "outputs": [],
-   "execution_count": 9
+   "execution_count": 6
+  },
+  {
+   "metadata": {
+    "id": "-U4t0kKIkDvb"
+   },
+   "cell_type": "markdown",
+   "source": [
+    "## Interfacing with low-level transforms"
+   ]
   },
   {
    "metadata": {
     "id": "75IG-22cKcuE"
    },
    "cell_type": "markdown",
    "source": [
-    "Sometimes, your computation doesn't fit into the ``map_blocks`` paradigm because you don't want to create `xarray.Dataset` objects. For these cases, you can switch to the lower-level Xarray-Beam [data model](data-model), and use raw Beam operations:"
+    "`Dataset` is a thin wrapper around Xarray-Beam transformations, so you can always drop into the lower-level Xarray-Beam [data model](data-model) and use raw Beam operations. This is especially useful for reading or writing data.\n",
+    "\n",
+    "To manually create a `Dataset` from a Beam PTransform, use {py:meth}`~xarray_beam.Dataset.from_ptransform`. Here's an example, showing the common pattern of evaluating a single example in-memory to produce the `xarray.Dataset` required for building a template:"
+   ]
+  },
+  {
+   "metadata": {
+    "id": "l9pHS1QDlMd-"
+   },
+   "cell_type": "code",
+   "source": [
+    "all_times = pd.date_range('2025-01-01', freq='1D', periods=365)\n",
+    "source_dataset = xarray.open_zarr('example_data.zarr', chunks=None)\n",
+    "\n",
+    "def load_chunk(time: pd.Timestamp) -\u003e tuple[xbeam.Key, xarray.Dataset]:\n",
+    "  key = xbeam.Key({'time': (time - all_times[0]).days})\n",
+    "  dataset = source_dataset.sel(time=[time])\n",
+    "  return key, dataset\n",
+    "\n",
+    "ptransform = beam.Create(all_times) | beam.Map(load_chunk)\n",
+    "\n",
+    "_, example = load_chunk(all_times[0])\n",
+    "template = xbeam.make_template(example)\n",
+    "template = xbeam.replace_template_dims(template, time=all_times)\n",
+    "\n",
+    "ds_beam = xbeam.Dataset.from_ptransform(\n",
+    "    ptransform, template=template, chunks={'time': 1}, split_vars=False\n",
+    ")\n",
+    "ds_beam"
+   ],
+   "outputs": [],
+   "execution_count": 12
+  },
+  {
+   "metadata": {
+    "id": "1qjeY5mwlLGJ"
+   },
+   "cell_type": "markdown",
+   "source": [
+    "You can also pull out the underlying Beam `ptransform` from a dataset to append new transformations, e.g., to write each element of the pipeline to disk as a separate file:"
    ]
   },
   {
@@ -288,16 +339,12 @@
     "  chunk.to_netcdf(path)\n",
     "\n",
     "with beam.Pipeline() as p:\n",
-    "  p | (\n",
-    "      xbeam.Dataset.from_zarr('example_data.zarr')\n",
-    "      .rechunk({'latitude': -1, 'longitude': -1})\n",
-    "      .ptransform\n",
-    "  ) | beam.MapTuple(to_netcdf)\n",
+    "  p | ds_beam.rechunk('50MB').ptransform | beam.MapTuple(to_netcdf)\n",
     "\n",
     "%ls *.nc"
    ],
    "outputs": [],
-   "execution_count": 10
+   "execution_count": 13
   }
  ],
  "metadata": {

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -55,4 +55,4 @@
     DatasetToZarr as DatasetToZarr,
 )

-__version__ = '0.10.3'  # automatically synchronized to pyproject.toml
+__version__ = '0.10.4'  # automatically synchronized to pyproject.toml

xarray_beam/_src/core.py

Lines changed: 16 additions & 11 deletions
@@ -505,24 +505,29 @@ def expand(self, pcoll):
     )


+def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
+  """Ensure that a dataset contains no chunked variables."""
+  for var_name, variable in dataset.variables.items():
+    if variable.chunks is not None:
+      raise ValueError(
+          f"Dataset variable {var_name!r} corresponding to key {key} is"
+          " chunked with Dask. Datasets passed to validate_chunk must be"
+          f" fully computed (not chunked): {dataset}\nThis typically arises"
+          " with datasets originating with `xarray.open_zarr()`, which by"
+          " default use Dask. If this is the case, you can fix it by passing"
+          " `chunks=None` or xarray_beam.open_zarr(). Alternatively, you"
+          " can load datasets explicitly into memory with `.compute()`."
+      )
+
+
 def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
   """Verify that a key and dataset(s) are valid for xarray-beam transforms."""
   if isinstance(datasets, xarray.Dataset):
     datasets: list[xarray.Dataset] = [datasets]

   for dataset in datasets:
     # Verify that no variables are chunked with Dask
-    for var_name, variable in dataset.variables.items():
-      if variable.chunks is not None:
-        raise ValueError(
-            f"Dataset variable {var_name!r} corresponding to key {key} is"
-            " chunked with Dask. Datasets passed to validate_chunk must be"
-            f" fully computed (not chunked): {dataset}\nThis typically arises"
-            " with datasets originating with `xarray.open_zarr()`, which by"
-            " default use Dask. If this is the case, you can fix it by passing"
-            " `chunks=None` or xarray_beam.open_zarr(). Alternatively, you"
-            " can load datasets explicitly into memory with `.compute()`."
-        )
+    _ensure_chunk_is_computed(key, dataset)

   # Validate key offsets
   missing_keys = [
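
For context, a minimal sketch of what this refactored check enforces, assuming Dask is installed and using the private import path `xarray_beam._src.core` shown in this diff (not a public API):

# Sketch: validate_chunk rejects Dask-backed variables. The import path
# below is internal to xarray-beam; it is used here for illustration only.
import numpy as np
import xarray
from xarray_beam import Key
from xarray_beam._src.core import validate_chunk

chunked = xarray.Dataset({'foo': ('x', np.arange(10))}).chunk({'x': 5})
try:
  validate_chunk(Key({'x': 0}), chunked)  # 'foo' is chunked with Dask
except ValueError as e:
  print(type(e).__name__)  # ValueError

validate_chunk(Key({'x': 0}), chunked.compute())  # fully computed: passes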

xarray_beam/_src/dataset.py

Lines changed: 151 additions & 6 deletions
@@ -227,6 +227,98 @@ def _infer_new_chunks(
   return new_chunks


+def _normalize_and_validate_chunk(
+    template: xarray.Dataset,
+    chunks: Mapping[str, int],
+    split_vars: bool,
+    key: core.Key,
+    dataset: xarray.Dataset,
+) -> tuple[core.Key, xarray.Dataset]:
+  """Validate and normalize (key, dataset) pairs for a Dataset."""
+
+  if split_vars:
+    if key.vars is None:
+      key = key.replace(vars=set(dataset.keys()))
+    elif key.vars != set(dataset.keys()):
+      raise ValueError(
+          f'dataset keys {sorted(dataset.keys())} do not match'
+          f' key.vars={sorted(key.vars)}'
+      )
+  elif key.vars is not None:
+    raise ValueError(f'must not set vars on key if split_vars=False: {key}')
+
+  new_offsets = dict(key.offsets)
+  for dim in dataset.dims:
+    if dim not in new_offsets:
+      new_offsets[dim] = 0
+  if len(new_offsets) != len(key.offsets):
+    key = key.replace(offsets=new_offsets)
+
+  core._ensure_chunk_is_computed(key, dataset)
+
+  def _with_dataset(msg: str):
+    dataset_repr = textwrap.indent(repr(dataset), prefix=' ')
+    return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}'
+
+  def _bad_template_error(msg: str):
+    template_repr = textwrap.indent(repr(template), prefix=' ')
+    raise ValueError(_with_dataset(msg) + f'Template:\n{template_repr}')
+
+  for k, v in dataset.items():
+    if k not in template:
+      _bad_template_error(
+          f'Chunk variable {k!r} not found in template variables '
+          f' {list(template.data_vars)}:'
+      )
+    if v.dtype != template[k].dtype:
+      _bad_template_error(
+          f'Chunk variable {k!r} has dtype {v.dtype} which does not match'
+          f' template variable dtype {template[k].dtype}:'
+      )
+    if v.dims != template[k].dims:
+      _bad_template_error(
+          f'Chunk variable {k!r} has dims {v.dims} which does not match'
+          f' template variable dims {template[k].dims}:'
+      )
+
+  for dim, size in dataset.sizes.items():
+    if dim not in chunks:
+      raise ValueError(
+          _with_dataset(
+              f'Dataset dimension {dim!r} not found in chunks {chunks}:'
+          )
+      )
+    offset = key.offsets[dim]
+    if offset % chunks[dim] != 0:
+      raise ValueError(
+          _with_dataset(
+              f'Chunk offset {offset} is not aligned with chunk '
+              f'size {chunks[dim]} for dimension {dim!r}:'
+          )
+      )
+    if offset + size > template.sizes[dim]:
+      _bad_template_error(
+          f'Chunk dimension {dim!r} has size {size} which is larger than the '
+          f'remaining size {template.sizes[dim] - offset} in the '
+          'template:'
+      )
+    is_last_chunk = offset + chunks[dim] > template.sizes[dim]
+    if is_last_chunk:
+      expected_size = template.sizes[dim] - offset
+      if size != expected_size:
+        _bad_template_error(
+            f'Chunk dimension {dim!r} is the last chunk, but has size {size} '
+            f'which does not match expected size {expected_size}:'
+        )
+    elif size != chunks[dim]:
+      _bad_template_error(
+          f'Chunk dimension {dim!r} has size {size} which does not match'
+          f' chunk size {chunks[dim]}:'
+      )
+
+  return key, dataset
+
+
 def _apply_to_each_chunk(
     func: Callable[[xarray.Dataset], xarray.Dataset],
     old_chunks: Mapping[str, int],
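
The offset and size rules enforced above can be made concrete with a small arithmetic sketch; the dimension and chunk sizes below are made-up values, not taken from this diff:

# Sketch of the chunk-alignment rules: every offset must be an exact
# multiple of the chunk size, and only the final chunk may be smaller.
dim_size, chunk_size = 10, 4  # hypothetical: valid offsets are 0, 4, 8
for offset in range(0, dim_size, chunk_size):
  assert offset % chunk_size == 0
  is_last_chunk = offset + chunk_size > dim_size
  expected_size = dim_size - offset if is_last_chunk else chunk_size
  print(f'offset={offset}: expected chunk size {expected_size}')
# offset=0 and offset=4 expect size 4; offset=8 expects the remainder, 2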
@@ -302,9 +394,8 @@ def __init__(
   ):
     """Low level interface for creating a new Dataset, without validation.

-    Most users should use the higher level
-    :py:class:`xarray_beam.Dataset.from_xarray` or
-    :py:class:`xarray_beam.Dataset.from_zarr` instead.
+    Unless you're really sure you don't need validation, prefer using
+    :py:class:`xarray_beam.Dataset.from_ptransform`.

     Args:
       template: xarray.Dataset describing the structure of this dataset,
@@ -317,9 +408,7 @@ def __init__(
       this dataset's data.
     """
     self._template = template
-    self._chunks = {
-        k: min(template.sizes[k], v) for k, v in chunks.items()
-    }
+    self._chunks = {k: min(template.sizes[k], v) for k, v in chunks.items()}
     self._split_vars = split_vars
     self._ptransform = ptransform

@@ -390,6 +479,62 @@ def __repr__(self):
         + textwrap.indent('\n'.join(base.split('\n')[1:]), ' ' * 4)
     )

+  @classmethod
+  def from_ptransform(
+      cls,
+      ptransform: beam.PTransform,
+      *,
+      template: xarray.Dataset,
+      chunks: Mapping[str | types.EllipsisType, int],
+      split_vars: bool = False,
+  ) -> Dataset:
+    """Create an xarray_beam.Dataset from a Beam PTransform.
+
+    This is an advanced constructor that allows you to create an
+    ``xarray_beam.Dataset`` from an existing Beam PTransform that produces
+    ``(Key, xarray.Dataset)`` pairs.
+
+    The PTransform should produce chunks that conform to the given
+    ``template``, ``chunks``, and ``split_vars`` arguments. This constructor
+    will add a validation step to the PTransform to normalize keys into the
+    strictest possible form based on the other arguments, and ensure that
+    transform outputs are valid.
+
+    Args:
+      ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)``
+        pairs. You only need to set ``offsets`` on these keys; ``vars`` will
+        be set automatically based on the dataset if ``split_vars`` is True.
+      template: An ``xarray.Dataset`` object representing the schema
+        (coordinates, dimensions, data variables, and attributes) of the full
+        dataset, as produced by :py:func:`xarray_beam.make_template`, with
+        data variables backed by Dask arrays.
+      chunks: A dictionary mapping dimension names to integer chunk sizes.
+        Every chunk produced by ``ptransform`` must have dimensions of these
+        sizes, except for the last chunk in each dimension, which may be
+        smaller.
+      split_vars: A boolean indicating whether the chunks in ``ptransform``
+        are split across variables, or if each chunk contains all variables.
+
+    Returns:
+      An ``xarray_beam.Dataset`` instance wrapping the PTransform.
+    """
+    if not isinstance(chunks, Mapping):
+      raise TypeError(
+          f'chunks must be a mapping for from_ptransform, got {chunks}'
+      )
+    for v in chunks.values():
+      if not isinstance(v, int):
+        raise TypeError(
+            'chunks must be a mapping with integer values for from_ptransform,'
+            f' got {chunks}'
+        )
+    chunks = normalize_chunks(chunks, template)
+    ptransform = ptransform | _get_label("validate") >> beam.MapTuple(
+        functools.partial(
+            _normalize_and_validate_chunk, template, chunks, split_vars
+        )
+    )
+    return cls(template, chunks, split_vars, ptransform)
+
   @classmethod
   def from_xarray(
       cls,
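
For reference, a minimal end-to-end sketch of the new constructor, mirroring the docstring above; the variable name, dimension name, and sizes are illustrative, not taken from this commit:

# Sketch: building an xarray_beam.Dataset from a hand-rolled PTransform.
# Names ('foo', 'x') and sizes are hypothetical.
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

size, chunk = 6, 2
template = xbeam.make_template(xarray.Dataset({'foo': ('x', np.zeros(size))}))

def make_chunk(offset: int) -> tuple[xbeam.Key, xarray.Dataset]:
  # Each chunk covers [offset, offset + chunk) along dimension 'x'.
  data = np.arange(offset, offset + chunk, dtype=float)
  return xbeam.Key({'x': offset}), xarray.Dataset({'foo': ('x', data)})

ptransform = beam.Create(range(0, size, chunk)) | beam.Map(make_chunk)
ds = xbeam.Dataset.from_ptransform(
    ptransform, template=template, chunks={'x': chunk}, split_vars=False
)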
