diff --git a/xarray_beam/_src/core.py b/xarray_beam/_src/core.py index 1ec3895..3aa80d9 100644 --- a/xarray_beam/_src/core.py +++ b/xarray_beam/_src/core.py @@ -16,7 +16,7 @@ from collections.abc import Hashable, Iterator, Mapping, Sequence, Set import contextlib -from functools import cached_property +import functools import itertools import math import pickle @@ -27,14 +27,15 @@ import immutabledict import numpy as np import xarray +from xarray_beam._src import range_source from xarray_beam._src import threadmap -T = TypeVar('T') +T = TypeVar("T") def export(obj: T) -> T: - obj.__module__ = 'xarray_beam' + obj.__module__ = "xarray_beam" return obj @@ -122,7 +123,6 @@ class Key: Key(indices={'x': 4}, vars={'bar'}) >>> key.with_indices(x=5) Key(indices={'x': 5}, vars={'bar'}) - """ # pylint: disable=redefined-builtin @@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key: """Replace some indices with new values. Args: - **indices: indices to override (for integer values) or remove, with - values of ``None``. + **indices: indices to override (for integer values) or remove, with values + of ``None``. Returns: New Key with the specified indices. @@ -421,49 +421,19 @@ def normalize_expanded_chunks( ) -@export -class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]): - """Split one or more xarray.Datasets into keyed chunks.""" +class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]): + """Base class for PTransforms that split Datasets into chunks.""" def __init__( self, dataset: DatasetOrDatasets, chunks: Mapping[str, int | tuple[int, ...]] | None = None, split_vars: bool = False, - num_threads: int | None = None, - shard_keys_threshold: int = 200_000, - tasks_per_shard: int = 10_000, ): - """Initialize DatasetToChunks. - - Args: - dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key, - [xarray.Dataset, ...]) pairs. - chunks: optional chunking scheme. Required if the dataset is *not* already - chunked. If the dataset *is* already chunked with Dask, `chunks` takes - precedence over the existing chunks. - split_vars: whether to split the dataset into separate records for each - data variable or to keep all data variables together. This is - recommended if you don't need to perform joint operations on different - dataset variables and individual variable chunks are sufficiently large. - num_threads: optional number of Dataset chunks to load in parallel per - worker. More threads can increase throughput, but also increases memory - usage and makes it harder for Beam runners to shard work. Note that each - variable in a Dataset is already loaded in parallel, so this is most - useful for Datasets with a small number of variables or when using - split_vars=True. - shard_keys_threshold: threshold at which to compute keys on Beam workers, - rather than only on the host process. This is important for scaling - pipelines to millions of tasks. - tasks_per_shard: number of tasks to emit per shard. Only used if the - number of tasks exceeds shard_keys_threshold. 
- """ + """Initialize _DatasetToChunksBase.""" self.dataset = dataset self._validate(dataset, split_vars) self.split_vars = split_vars - self.num_threads = num_threads - self.shard_keys_threshold = shard_keys_threshold - self.tasks_per_shard = tasks_per_shard if chunks is None: dask_chunks = self._first.chunks @@ -489,15 +459,15 @@ def _datasets(self) -> list[xarray.Dataset]: return [self.dataset] return list(self.dataset) # pytype: disable=bad-return-type - @cached_property + @functools.cached_property def expanded_chunks(self) -> dict[str, tuple[int, ...]]: return normalize_expanded_chunks(self.chunks, self._first.sizes) # pytype: disable=wrong-arg-types # always-use-property-annotation - @cached_property + @functools.cached_property def offsets(self) -> dict[str, list[int]]: return _chunks_to_offsets(self.expanded_chunks) - @cached_property + @functools.cached_property def offset_index(self) -> dict[str, dict[int, int]]: return compute_offset_index(self.offsets) @@ -542,7 +512,78 @@ def _task_count(self) -> int: total += int(np.prod(count_list)) return total - @cached_property + def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]: + """Convert a Key into an in-memory (Key, xarray.Dataset) pair.""" + with inc_timer_msec(self.__class__, "read-msec"): + sizes = { + dim: self.expanded_chunks[dim][self.offset_index[dim][offset]] + for dim, offset in key.offsets.items() + } + slices = offsets_to_slices(key.offsets, sizes) + results = [] + for ds in self._datasets: + dataset = ds if key.vars is None else ds[list(key.vars)] + valid_slices = {k: v for k, v in slices.items() if k in dataset.dims} + chunk = dataset.isel(valid_slices) + # Load the data, using a separate thread for each variable + num_threads = len(dataset) + result = chunk.chunk().compute(num_workers=num_threads) + results.append(result) + + inc_counter(self.__class__, "read-chunks") + inc_counter( + self.__class__, "read-bytes", sum(result.nbytes for result in results) + ) + + if isinstance(self.dataset, xarray.Dataset): + return key, results[0] + else: + return key, results + + +@export +class DatasetToChunks(_DatasetToChunksBase): + """Split one or more xarray.Datasets into keyed chunks.""" + + def __init__( + self, + dataset: DatasetOrDatasets, + chunks: Mapping[str, int | tuple[int, ...]] | None = None, + split_vars: bool = False, + num_threads: int | None = None, + shard_keys_threshold: int = 200_000, + tasks_per_shard: int = 10_000, + ): + """Initialize DatasetToChunks. + + Args: + dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key, + [xarray.Dataset, ...]) pairs. + chunks: optional chunking scheme. Required if the dataset is *not* already + chunked. If the dataset *is* already chunked with Dask, `chunks` takes + precedence over the existing chunks. + split_vars: whether to split the dataset into separate records for each + data variable or to keep all data variables together. This is + recommended if you don't need to perform joint operations on different + dataset variables and individual variable chunks are sufficiently large. + num_threads: optional number of Dataset chunks to load in parallel per + worker. More threads can increase throughput, but also increases memory + usage and makes it harder for Beam runners to shard work. Note that each + variable in a Dataset is already loaded in parallel, so this is most + useful for Datasets with a small number of variables or when using + split_vars=True. 
+ shard_keys_threshold: threshold at which to compute keys on Beam workers, + rather than only on the host process. This is important for scaling + pipelines to millions of tasks. + tasks_per_shard: number of tasks to emit per shard. Only used if the + number of tasks exceeds shard_keys_threshold. + """ + super().__init__(dataset, chunks, split_vars) + self.num_threads = num_threads + self.shard_keys_threshold = shard_keys_threshold + self.tasks_per_shard = tasks_per_shard + + @functools.cached_property def sharded_dim(self) -> str | None: # We use the simple heuristic of only sharding inputs along the dimension # with the most chunks. @@ -552,7 +593,7 @@ def sharded_dim(self) -> str | None: } return max(lengths, key=lengths.get) if lengths else None # pytype: disable=bad-return-type - @cached_property + @functools.cached_property def shard_count(self) -> int | None: """Determine the number of times to shard input keys.""" task_count = self._task_count() @@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]: inputs.append((None, name)) return inputs # pytype: disable=bad-return-type # always-use-property-annotation - def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]: - """Convert a Key into an in-memory (Key, xarray.Dataset) pair.""" - with inc_timer_msec(self.__class__, "read-msec"): - sizes = { - dim: self.expanded_chunks[dim][self.offset_index[dim][offset]] - for dim, offset in key.offsets.items() - } - slices = offsets_to_slices(key.offsets, sizes) - results = [] - for ds in self._datasets: - dataset = ds if key.vars is None else ds[list(key.vars)] - valid_slices = {k: v for k, v in slices.items() if k in dataset.dims} - chunk = dataset.isel(valid_slices) - # Load the data, using a separate thread for each variable - num_threads = len(dataset) - result = chunk.chunk().compute(num_workers=num_threads) - results.append(result) - - inc_counter(self.__class__, "read-chunks") - inc_counter( - self.__class__, "read-bytes", sum(result.nbytes for result in results) - ) - - if isinstance(self.dataset, xarray.Dataset): - yield key, results[0] - else: - yield key, results - def expand(self, pcoll): if self.shard_count is None: # Create all keys on the machine launching the Beam pipeline. This is @@ -652,11 +665,102 @@ def expand(self, pcoll): | beam.Reshuffle() ) - return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap( + return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap( self._key_to_chunks, num_threads=self.num_threads ) +# TODO(shoyer): expose this function as a public API, after switching it to +# generate Key objects using `indices` instead of `offsets`. +class ReadDataset(_DatasetToChunksBase): + """Read chunks from an xarray.Dataset into a Beam pipeline. + + This PTransform is a Beam "splittable DoFn", which means that it may be + dynamically split by Beam runners into smaller chunks for efficient parallel + execution. + """ + + def __init__( + self, + dataset: xarray.Dataset, + chunks: Mapping[str, int | tuple[int, ...]] | None = None, + split_vars: bool = False, + ): + """Initialize ReadDatasets. + + Args: + dataset: dataset to split into (Key, xarray.Dataset) chunks. + chunks: optional chunking scheme. Required if the dataset is *not* already + chunked. If the dataset *is* already chunked with Dask, `chunks` takes + precedence over the existing chunks. + split_vars: whether to split the dataset into separate records for each + data variable or to keep all data variables together. 
This is + recommended if you don't need to perform joint operations on different + dataset variables and individual variable chunks are sufficiently large. + """ + super().__init__(dataset, chunks, split_vars) + + @functools.cached_property + def _chunk_index_shapes( + self, + ) -> list[tuple[str | None, tuple[str, ...], tuple[int, ...]]]: + """Calculate the shapes of indices for each chunk of the data. + + The result here is a list of tuples of the form (name, dims, shape), where + name is the name of the variable (or None if all variables are consolidated) + and dims and shape are the dimensions along which the variable's chunk is + indexed, and shape of that chunk in _indices_. For example, if the dataset + had a variable `foo` with dimensions `('x', 'y')`, shape (10, 10) with + chunks `{'x': 5, 'y': 2}`, then this function would return a corresponding + list entry `('foo', ('x', 'y'), (2, 5))`. + """ + out = [] + if not self.split_vars: + dims = sorted(self.expanded_chunks) + shape = tuple(len(self.expanded_chunks[dim]) for dim in dims) + out.append((None, dims, shape)) + else: + for name, variable in self._first.items(): + dims = tuple(d for d in variable.dims if d in self.expanded_chunks) + shape = tuple(len(self.expanded_chunks[dim]) for dim in dims) + out.append((name, dims, shape)) + return out # pytype: disable=bad-return-type + + @functools.cached_property + def _cumulative_sizes(self) -> np.ndarray: + var_sizes = [math.prod(shape) for _, _, shape in self._chunk_index_shapes] + return np.cumsum([0] + var_sizes) + + def _index_to_key(self, position: int) -> Key: + assert 0 <= position < self._cumulative_sizes[-1] + var_index = ( + np.searchsorted(self._cumulative_sizes, position, side="right") - 1 + ) + offset = position - self._cumulative_sizes[var_index] + name, dims, shape = self._chunk_index_shapes[var_index] + indices = np.unravel_index(offset, shape) + offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)} + return Key(offsets, vars=None if name is None else {name}) + + def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]: + return self._key_to_chunks(self._index_to_key(position)) # pytype: disable=bad-return-type + + def expand( + self, pbegin: beam.PBegin + ) -> beam.PCollection[tuple[Key, xarray.Dataset]]: + element_count = self._task_count() + assert element_count > 0 + # For simplicity, assume that all chunks are approximately the same size, + # even if variables are being split and some variables have different + # variables. This assumption could be relaxed in the future, with an + # improved version of RangeSource. + avg_chunk_bytes = math.ceil(self._first.nbytes / element_count) + source = range_source.RangeSource( + element_count, avg_chunk_bytes, self._get_element + ) + return pbegin | beam.io.Read(source) + + def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None: """Ensure that a dataset contains no chunked variables.""" for var_name, variable in dataset.variables.items(): diff --git a/xarray_beam/_src/core_test.py b/xarray_beam/_src/core_test.py index 7028761..f44f6ac 100644 --- a/xarray_beam/_src/core_test.py +++ b/xarray_beam/_src/core_test.py @@ -13,13 +13,13 @@ # limitations under the License. 
"""Tests for xarray_beam._src.core.""" +import pickle import re from absl.testing import absltest from absl.testing import parameterized import apache_beam as beam import dask.array as da import immutabledict -import pickle import numpy as np import xarray import xarray_beam as xbeam @@ -252,9 +252,7 @@ def test_vars_as_beam_key(self): self.assertEqual(actual, expected) def test_pickle(self): - key = xbeam.Key( - {'x': 0, 'y': 10}, vars={'foo'} - ) + key = xbeam.Key({'x': 0, 'y': 10}, vars={'foo'}) unpickled = pickle.loads(pickle.dumps(key)) self.assertEqual(key, unpickled) @@ -655,6 +653,75 @@ def test_validate(self): self.fail('should allow a pipeline where the first has more dimensions.') +class ReadDatasetTest(test_util.TestCase): + + def test_basics(self): + dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) + expected = [ + (xbeam.Key({'x': 0}), dataset.head(x=3)), + (xbeam.Key({'x': 3}), dataset.tail(x=3)), + ] + actual = test_util.EagerPipeline() | core.ReadDataset( + dataset.chunk({'x': 3}) + ) + self.assertIdenticalChunks(actual, expected) + + actual = test_util.EagerPipeline() | core.ReadDataset( + dataset, chunks={'x': 3} + ) + self.assertIdenticalChunks(actual, expected) + + def test_whole_dataset(self): + dataset = xarray.Dataset({'foo': ('x', np.arange(6))}) + expected = [(xbeam.Key({'x': 0}), dataset)] + actual = test_util.EagerPipeline() | core.ReadDataset( + dataset, chunks={'x': -1} + ) + self.assertIdenticalChunks(actual, expected) + + def test_different_vars(self): + dataset = xarray.Dataset({ + 'foo': ('x', np.arange(6)), + 'bar': ('x', -np.arange(6)), + }) + expected = [ + (xbeam.Key({'x': 0}, {'foo'}), dataset.head(x=3)[['foo']]), + (xbeam.Key({'x': 0}, {'bar'}), dataset.head(x=3)[['bar']]), + (xbeam.Key({'x': 3}, {'foo'}), dataset.tail(x=3)[['foo']]), + (xbeam.Key({'x': 3}, {'bar'}), dataset.tail(x=3)[['bar']]), + ] + actual = test_util.EagerPipeline() | core.ReadDataset( + dataset, chunks={'x': 3}, split_vars=True + ) + self.assertIdenticalChunks(actual, expected) + + def test_split_with_different_dims(self): + dataset = xarray.Dataset({ + 'foo': (('x', 'y'), np.array([[1, 2, 3], [4, 5, 6]])), + 'bar': ('x', np.array([1, 2])), + 'baz': ('z', np.array([1, 2, 3])), + }) + expected = [ + (xbeam.Key({'x': 0, 'y': 0}, {'foo'}), dataset[['foo']].head(x=1)), + (xbeam.Key({'x': 0}, {'bar'}), dataset[['bar']].head(x=1)), + (xbeam.Key({'x': 1, 'y': 0}, {'foo'}), dataset[['foo']].tail(x=1)), + (xbeam.Key({'x': 1}, {'bar'}), dataset[['bar']].tail(x=1)), + (xbeam.Key({'z': 0}, {'baz'}), dataset[['baz']]), + ] + actual = test_util.EagerPipeline() | core.ReadDataset( + dataset, + chunks={'x': 1}, + split_vars=True, + ) + self.assertIdenticalChunks(actual, expected) + + def test_read_datasets_empty(self): + dataset = xarray.Dataset() + expected = [(xbeam.Key({}), dataset)] + actual = test_util.EagerPipeline() | core.ReadDataset(dataset, chunks={}) + self.assertIdenticalChunks(actual, expected) + + class ValidateEachChunkTest(test_util.TestCase): def test_validate_chunk_raises_on_dask_chunked(self): diff --git a/xarray_beam/_src/range_source.py b/xarray_beam/_src/range_source.py index 20ad3e9..82519eb 100644 --- a/xarray_beam/_src/range_source.py +++ b/xarray_beam/_src/range_source.py @@ -16,11 +16,10 @@ import dataclasses import math -from typing import Any, Callable, Generic, Iterator, TypeVar +from typing import Callable, Generic, Iterator, TypeVar import apache_beam as beam from apache_beam.io import iobase -from apache_beam.io import range_trackers _T = TypeVar('_T') @@ 
-36,8 +35,8 @@ class RangeSource(iobase.BoundedSource, Generic[_T]): Attributes: element_count: number of elements in this source. element_size: size of each element in bytes. - get_element: callable that given an integer index in the range - ``[0, element_count)`` returns the corresponding element of the source. + get_element: callable that given an integer index in the range ``[0, + element_count)`` returns the corresponding element of the source. """ element_count: int @@ -50,8 +49,10 @@ def __post_init__(self): raise ValueError( f'element_count must be non-negative: {self.element_count}' ) - if self.element_size <= 0: - raise ValueError(f'element_size must be positive: {self.element_size}') + if self.element_size < 0: + raise ValueError( + f'element_size must be non-negative: {self.element_size}' + ) def estimate_size(self) -> int: """Estimates the size of source in bytes.""" @@ -68,7 +69,7 @@ def split( stop = stop_position if stop_position is not None else self.element_count bundle_size_in_elements = int( - math.ceil(desired_bundle_size / self.element_size) + math.ceil(desired_bundle_size / max(self.element_size, 1)) ) for bundle_start in range(start, stop, bundle_size_in_elements): bundle_stop = min(bundle_start + bundle_size_in_elements, stop) @@ -79,14 +80,14 @@ def get_range_tracker( self, start_position: int | None, stop_position: int | None, - ) -> range_trackers.OffsetRangeTracker: + ) -> beam.io.OffsetRangeTracker: """Returns a RangeTracker for a given position range.""" start = start_position if start_position is not None else 0 stop = stop_position if stop_position is not None else self.element_count - return range_trackers.OffsetRangeTracker(start, stop) + return beam.io.OffsetRangeTracker(start, stop) def read( - self, range_tracker: range_trackers.OffsetRangeTracker + self, range_tracker: beam.io.OffsetRangeTracker ) -> Iterator[_T]: """Returns an iterator that reads data from the source.""" i = range_tracker.start_position()
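To ground the review, a few worked sketches follow; all literals and helper names in them are illustrative rather than part of this change. First, the offset bookkeeping shared by both transforms: `_key_to_chunks` turns a Key's offsets into index slices via `expanded_chunks` and `offset_index`. A standalone sketch of that arithmetic for the `chunks={'x': 3}`, length-6 case used in `ReadDatasetTest.test_basics`:

    # Standalone illustration of the offset -> slice bookkeeping. The literals
    # mirror the expanded_chunks/offsets properties for chunks={'x': 3} on a
    # length-6 dimension; this is not the library code itself.
    key_offsets = {'x': 3}              # Key(offsets={'x': 3})
    expanded_chunks = {'x': (3, 3)}     # normalized chunk sizes per dimension
    offsets = {'x': [0, 3]}             # cumulative chunk start offsets
    sizes = {
        dim: expanded_chunks[dim][offsets[dim].index(start)]
        for dim, start in key_offsets.items()
    }
    slices = {
        dim: slice(start, start + sizes[dim])
        for dim, start in key_offsets.items()
    }
    assert slices == {'x': slice(3, 6)}  # i.e. dataset.isel(x=slice(3, 6))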
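Per the TODO, `ReadDataset` is not exported yet, so a pipeline would reach it through the private module path. A minimal usage sketch, reusing the dataset from `ReadDatasetTest.test_basics`; unlike `DatasetToChunks`, it expands directly from `PBegin` via `beam.io.Read`, so the runner can rebalance the read while it is in progress:

    # Usage sketch only; ReadDataset is not public API yet (see the TODO above).
    import apache_beam as beam
    import numpy as np
    import xarray
    from xarray_beam._src import core

    ds = xarray.Dataset({'foo': ('x', np.arange(6))})
    with beam.Pipeline() as p:
        # Yields (xarray_beam.Key, xarray.Dataset) pairs, one per chunk.
        chunks = p | core.ReadDataset(ds, chunks={'x': 3})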
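The flat-position-to-Key mapping in `_index_to_key` concatenates the per-variable chunk grids (via `_cumulative_sizes`) and unravels the remainder. A worked example, reusing the `('foo', ('x', 'y'), (2, 5))` entry from the `_chunk_index_shapes` docstring plus a hypothetical `bar` variable chunked only along `x`:

    # Worked example of the flat-position -> (variable, chunk indices) mapping.
    import math
    import numpy as np

    chunk_index_shapes = [('foo', ('x', 'y'), (2, 5)), ('bar', ('x',), (2,))]
    cumulative = np.cumsum([0] + [math.prod(s) for _, _, s in chunk_index_shapes])
    # cumulative == [0, 10, 12]: positions 0-9 belong to 'foo', 10-11 to 'bar'.

    position = 7
    var_index = np.searchsorted(cumulative, position, side='right') - 1
    name, dims, shape = chunk_index_shapes[var_index]
    indices = np.unravel_index(position - cumulative[var_index], shape)
    # -> name == 'foo', indices == (1, 2): second chunk along x, third along y.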
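On the `range_source.py` changes: `element_size` may now be zero because `ReadDataset.expand` estimates it as `ceil(nbytes / element_count)`, which comes out to 0 for the empty dataset covered by `test_read_datasets_empty`; `split` compensates with `max(self.element_size, 1)`. The bundle arithmetic, with illustrative numbers:

    # Bundle-splitting arithmetic for RangeSource; the numbers are made up.
    import math

    element_count = 100                # e.g. 100 chunks
    element_size = 2_000_000           # ~2 MB per chunk (may be 0; see above)
    desired_bundle_size = 16_000_000   # the runner asks for ~16 MB bundles

    bundle_elements = math.ceil(desired_bundle_size / max(element_size, 1))
    bundles = [
        (start, min(start + bundle_elements, element_count))
        for start in range(0, element_count, bundle_elements)
    ]
    assert bundle_elements == 8 and len(bundles) == 13  # last bundle has 4 elements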
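Finally, for context on how the tracker enables dynamic splitting: a conventional `OffsetRangeTracker` read loop claims one offset at a time, and a False return from `try_claim` is how the runner takes back the unread tail of the range. A generic sketch of that pattern (not a transcription of `RangeSource.read`):

    # Generic BoundedSource read loop over an OffsetRangeTracker; a sketch of
    # the standard pattern only.
    def read_range(get_element, range_tracker):
        i = range_tracker.start_position()
        while range_tracker.try_claim(i):
            # get_element maps an integer offset to the corresponding element.
            yield get_element(i)
            i += 1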