
Commit 647d7e4

shoyer authored and Xarray-Beam authors committed
Use RangeSource in ReadDatasets
PiperOrigin-RevId: 827046303
1 parent 1e5ddef · commit 647d7e4

6 files changed: +442 -81 lines changed


xarray_beam/_src/core.py

Lines changed: 164 additions & 67 deletions
@@ -27,14 +27,15 @@
 import immutabledict
 import numpy as np
 import xarray
+from xarray_beam._src import range_source
 from xarray_beam._src import threadmap


-T = TypeVar('T')
+T = TypeVar("T")


 def export(obj: T) -> T:
-  obj.__module__ = 'xarray_beam'
+  obj.__module__ = "xarray_beam"
   return obj

@@ -122,7 +123,6 @@ class Key:
   Key(indices={'x': 4}, vars={'bar'})
   >>> key.with_indices(x=5)
   Key(indices={'x': 5}, vars={'bar'})
-
   """

   # pylint: disable=redefined-builtin
@@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key:
     """Replace some indices with new values.

     Args:
-      **indices: indices to override (for integer values) or remove, with
-        values of ``None``.
+      **indices: indices to override (for integer values) or remove, with values
+        of ``None``.

     Returns:
       New Key with the specified indices.
@@ -421,49 +421,19 @@ def normalize_expanded_chunks(
   )


-@export
-class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
-  """Split one or more xarray.Datasets into keyed chunks."""
+class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]):
+  """Base class for PTransforms that split Datasets into chunks."""

   def __init__(
       self,
       dataset: DatasetOrDatasets,
       chunks: Mapping[str, int | tuple[int, ...]] | None = None,
       split_vars: bool = False,
-      num_threads: int | None = None,
-      shard_keys_threshold: int = 200_000,
-      tasks_per_shard: int = 10_000,
   ):
-    """Initialize DatasetToChunks.
-
-    Args:
-      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
-        [xarray.Dataset, ...]) pairs.
-      chunks: optional chunking scheme. Required if the dataset is *not* already
-        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
-        precedence over the existing chunks.
-      split_vars: whether to split the dataset into separate records for each
-        data variable or to keep all data variables together. This is
-        recommended if you don't need to perform joint operations on different
-        dataset variables and individual variable chunks are sufficiently large.
-      num_threads: optional number of Dataset chunks to load in parallel per
-        worker. More threads can increase throughput, but also increases memory
-        usage and makes it harder for Beam runners to shard work. Note that each
-        variable in a Dataset is already loaded in parallel, so this is most
-        useful for Datasets with a small number of variables or when using
-        split_vars=True.
-      shard_keys_threshold: threshold at which to compute keys on Beam workers,
-        rather than only on the host process. This is important for scaling
-        pipelines to millions of tasks.
-      tasks_per_shard: number of tasks to emit per shard. Only used if the
-        number of tasks exceeds shard_keys_threshold.
-    """
+    """Initialize _DatasetToChunksBase."""
     self.dataset = dataset
     self._validate(dataset, split_vars)
     self.split_vars = split_vars
-    self.num_threads = num_threads
-    self.shard_keys_threshold = shard_keys_threshold
-    self.tasks_per_shard = tasks_per_shard

     if chunks is None:
       dask_chunks = self._first.chunks
@@ -542,6 +512,77 @@ def _task_count(self) -> int:
       total += int(np.prod(count_list))
     return total

+  def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]:
+    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
+    with inc_timer_msec(self.__class__, "read-msec"):
+      sizes = {
+          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
+          for dim, offset in key.offsets.items()
+      }
+      slices = offsets_to_slices(key.offsets, sizes)
+      results = []
+      for ds in self._datasets:
+        dataset = ds if key.vars is None else ds[list(key.vars)]
+        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
+        chunk = dataset.isel(valid_slices)
+        # Load the data, using a separate thread for each variable
+        num_threads = len(dataset)
+        result = chunk.chunk().compute(num_workers=num_threads)
+        results.append(result)
+
+    inc_counter(self.__class__, "read-chunks")
+    inc_counter(
+        self.__class__, "read-bytes", sum(result.nbytes for result in results)
+    )
+
+    if isinstance(self.dataset, xarray.Dataset):
+      return key, results[0]
+    else:
+      return key, results
+
+
+@export
+class DatasetToChunks(_DatasetToChunksBase):
+  """Split one or more xarray.Datasets into keyed chunks."""
+
+  def __init__(
+      self,
+      dataset: DatasetOrDatasets,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+      num_threads: int | None = None,
+      shard_keys_threshold: int = 200_000,
+      tasks_per_shard: int = 10_000,
+  ):
+    """Initialize DatasetToChunks.
+
+    Args:
+      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
+        [xarray.Dataset, ...]) pairs.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+      num_threads: optional number of Dataset chunks to load in parallel per
+        worker. More threads can increase throughput, but also increases memory
+        usage and makes it harder for Beam runners to shard work. Note that each
+        variable in a Dataset is already loaded in parallel, so this is most
+        useful for Datasets with a small number of variables or when using
+        split_vars=True.
+      shard_keys_threshold: threshold at which to compute keys on Beam workers,
+        rather than only on the host process. This is important for scaling
+        pipelines to millions of tasks.
+      tasks_per_shard: number of tasks to emit per shard. Only used if the
+        number of tasks exceeds shard_keys_threshold.
+    """
+    super().__init__(dataset, chunks, split_vars)
+    self.num_threads = num_threads
+    self.shard_keys_threshold = shard_keys_threshold
+    self.tasks_per_shard = tasks_per_shard
+
   @cached_property
   def sharded_dim(self) -> str | None:
     # We use the simple heuristic of only sharding inputs along the dimension
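For orientation, here is a minimal usage sketch of the public DatasetToChunks transform described by the docstring above. It is not part of this commit; the toy dataset and the chunk size are invented for illustration.

import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

# Toy dataset: 12 time steps x 4 points, split into two chunks along "time".
ds = xarray.Dataset(
    {"temperature": (("time", "x"), np.random.randn(12, 4))},
    coords={"time": np.arange(12), "x": np.arange(4)},
)

with beam.Pipeline() as p:
  _ = (
      p
      | xbeam.DatasetToChunks(ds, chunks={"time": 6})
      # Each element is a (Key, xarray.Dataset) pair covering one chunk.
      | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
  )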
@@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
       inputs.append((None, name))
     return inputs  # pytype: disable=bad-return-type  # always-use-property-annotation

-  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
-    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
-    with inc_timer_msec(self.__class__, "read-msec"):
-      sizes = {
-          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
-          for dim, offset in key.offsets.items()
-      }
-      slices = offsets_to_slices(key.offsets, sizes)
-      results = []
-      for ds in self._datasets:
-        dataset = ds if key.vars is None else ds[list(key.vars)]
-        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
-        chunk = dataset.isel(valid_slices)
-        # Load the data, using a separate thread for each variable
-        num_threads = len(dataset)
-        result = chunk.chunk().compute(num_workers=num_threads)
-        results.append(result)
-
-    inc_counter(self.__class__, "read-chunks")
-    inc_counter(
-        self.__class__, "read-bytes", sum(result.nbytes for result in results)
-    )
-
-    if isinstance(self.dataset, xarray.Dataset):
-      yield key, results[0]
-    else:
-      yield key, results
-
   def expand(self, pcoll):
     if self.shard_count is None:
       # Create all keys on the machine launching the Beam pipeline. This is
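The generator-based _key_to_chunks removed here is re-added on the base class above as a plain method returning exactly one (Key, Dataset) pair, which is why the next hunk swaps FlatThreadMap for ThreadMap. The distinction mirrors Beam's built-in Map vs FlatMap, sketched below with the vanilla (non-threaded) transforms; the function names are invented for illustration.

import apache_beam as beam

def one_chunk(key):
  # Map-style: return exactly one output element per input.
  return (key, f"chunk-{key}")

def maybe_many_chunks(key):
  # FlatMap-style: yield zero or more output elements per input.
  yield (key, f"chunk-{key}")

with beam.Pipeline() as p:
  keys = p | beam.Create([0, 1, 2])
  _ = keys | "Map" >> beam.Map(one_chunk) | "PrintMap" >> beam.Map(print)
  _ = keys | "FlatMap" >> beam.FlatMap(maybe_many_chunks) | "PrintFlat" >> beam.Map(print)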
@@ -652,11 +665,95 @@ def expand(self, pcoll):
         | beam.Reshuffle()
     )

-    return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
+    return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap(
         self._key_to_chunks, num_threads=self.num_threads
     )


+# TODO(shoyer): expose this function as a public API, after switching it to
+# generate Key objects using `indices` instead of `offsets`.
+class ReadDataset(_DatasetToChunksBase):
+  """Read chunks from an xarray.Dataset into a Beam pipeline.
+
+  This PTransform is a Beam "splittable DoFn", which means that it may be
+  dynamically split by Beam runners into smaller chunks for efficient parallel
+  execution.
+  """
+
+  def __init__(
+      self,
+      dataset: xarray.Dataset,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+  ):
+    """Initialize ReadDatasets.
+
+    Args:
+      dataset: dataset to split into (Key, xarray.Dataset) chunks.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+    """
+    super().__init__(dataset, chunks, split_vars)
+
+  @cached_property
+  def _var_chunk_counts(
+      self,
+  ) -> list[tuple[str | None, list[str], tuple[int, ...]]]:
+    out = []
+    if not self.split_vars:
+      dims = sorted(self.expanded_chunks)
+      shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+      out.append((None, dims, shape))
+    else:
+      for name, variable in self._first.items():
+        dims = sorted([d for d in variable.dims if d in self.expanded_chunks])
+        shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+        out.append((name, dims, shape))
+    return out  # pytype: disable=bad-return-type
+
+  @cached_property
+  def _var_sizes(self) -> list[int]:
+    return [int(np.prod(shape)) for _, _, shape in self._var_chunk_counts]
+
+  @cached_property
+  def _cumulative_sizes(self) -> np.ndarray:
+    return np.cumsum([0] + self._var_sizes)
+
+  def _index_to_key(self, position: int) -> Key:
+    assert 0 <= position < self._cumulative_sizes[-1]
+    var_index = (
+        np.searchsorted(self._cumulative_sizes, position, side="right") - 1
+    )
+    offset = position - self._cumulative_sizes[var_index]
+    name, dims, shape = self._var_chunk_counts[var_index]
+    indices = np.unravel_index(offset, shape)
+    offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)}
+    return Key(offsets, vars=None if name is None else {name})
+
+  def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]:
+    return self._key_to_chunks(self._index_to_key(position))  # pytype: disable=bad-return-type
+
+  def expand(
+      self, pbegin: beam.PBegin
+  ) -> beam.PCollection[tuple[Key, xarray.Dataset]]:
+    element_count = self._task_count()
+    assert element_count > 0
+    # For simplicity, assume that all chunks are approximately the same size,
+    # even if variables are being split and some variables have different
+    # variables. This assumption could be relaxed in the future, with an
+    # improved version of RangeSource.
+    avg_chunk_bytes = math.ceil(self._first.nbytes / element_count)
+    source = range_source.RangeSource(
+        element_count, avg_chunk_bytes, self._get_element
+    )
+    return pbegin | beam.io.Read(source)
+
+
 def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
   """Ensure that a dataset contains no chunked variables."""
   for var_name, variable in dataset.variables.items():
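The new ReadDataset maps a single integer position from RangeSource back to a chunk Key via _cumulative_sizes, np.searchsorted, and np.unravel_index. Below is a standalone sketch of that arithmetic; it is not part of the commit, and the variable names and chunk-grid shapes are invented for illustration.

import numpy as np

# With split_vars=True, suppose two variables with these chunk-grid shapes:
#   "u": 3 time-chunks x 2 x-chunks -> 6 chunks
#   "v": 4 time-chunks              -> 4 chunks
var_chunk_counts = [("u", ["time", "x"], (3, 2)), ("v", ["time"], (4,))]
var_sizes = [int(np.prod(shape)) for _, _, shape in var_chunk_counts]
cumulative = np.cumsum([0] + var_sizes)  # [0, 6, 10]

def index_to_var_and_grid(position: int):
  # Which variable's chunk grid does this flat position fall into?
  var_index = np.searchsorted(cumulative, position, side="right") - 1
  offset = position - cumulative[var_index]
  name, dims, shape = var_chunk_counts[var_index]
  # Unravel the remaining offset into per-dimension chunk-grid indices.
  grid_indices = np.unravel_index(offset, shape)
  return name, dict(zip(dims, (int(i) for i in grid_indices)))

print(index_to_var_and_grid(0))  # ('u', {'time': 0, 'x': 0})
print(index_to_var_and_grid(5))  # ('u', {'time': 2, 'x': 1})
print(index_to_var_and_grid(7))  # ('v', {'time': 1})

Flattening every variable's chunk grid into one contiguous integer range is what lets RangeSource (and hence Beam runners) split the read dynamically.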

0 commit comments

Comments
 (0)