250 changes: 177 additions & 73 deletions xarray_beam/_src/core.py
@@ -16,7 +16,7 @@

from collections.abc import Hashable, Iterator, Mapping, Sequence, Set
import contextlib
from functools import cached_property
import functools
import itertools
import math
import pickle
@@ -27,14 +27,15 @@
import immutabledict
import numpy as np
import xarray
from xarray_beam._src import range_source
from xarray_beam._src import threadmap


T = TypeVar('T')
T = TypeVar("T")


def export(obj: T) -> T:
obj.__module__ = 'xarray_beam'
obj.__module__ = "xarray_beam"
return obj


@@ -122,7 +123,6 @@ class Key:
Key(indices={'x': 4}, vars={'bar'})
>>> key.with_indices(x=5)
Key(indices={'x': 5}, vars={'bar'})

"""

# pylint: disable=redefined-builtin
@@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key:
"""Replace some indices with new values.

Args:
**indices: indices to override (for integer values) or remove, with
values of ``None``.
**indices: indices to override (for integer values) or remove, with values
of ``None``.

Returns:
New Key with the specified indices.
Expand Down Expand Up @@ -421,49 +421,19 @@ def normalize_expanded_chunks(
)


@export
class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
"""Split one or more xarray.Datasets into keyed chunks."""
class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]):
"""Base class for PTransforms that split Datasets into chunks."""

def __init__(
self,
dataset: DatasetOrDatasets,
chunks: Mapping[str, int | tuple[int, ...]] | None = None,
split_vars: bool = False,
num_threads: int | None = None,
shard_keys_threshold: int = 200_000,
tasks_per_shard: int = 10_000,
):
"""Initialize DatasetToChunks.

Args:
dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
[xarray.Dataset, ...]) pairs.
chunks: optional chunking scheme. Required if the dataset is *not* already
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variable or to keep all data variables together. This is
recommended if you don't need to perform joint operations on different
dataset variables and individual variable chunks are sufficiently large.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
usage and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already loaded in parallel, so this is most
useful for Datasets with a small number of variables or when using
split_vars=True.
shard_keys_threshold: threshold at which to compute keys on Beam workers,
rather than only on the host process. This is important for scaling
pipelines to millions of tasks.
tasks_per_shard: number of tasks to emit per shard. Only used if the
number of tasks exceeds shard_keys_threshold.
"""
"""Initialize _DatasetToChunksBase."""
self.dataset = dataset
self._validate(dataset, split_vars)
self.split_vars = split_vars
self.num_threads = num_threads
self.shard_keys_threshold = shard_keys_threshold
self.tasks_per_shard = tasks_per_shard

if chunks is None:
dask_chunks = self._first.chunks
@@ -489,15 +459,15 @@ def _datasets(self) -> list[xarray.Dataset]:
return [self.dataset]
return list(self.dataset) # pytype: disable=bad-return-type

@cached_property
@functools.cached_property
def expanded_chunks(self) -> dict[str, tuple[int, ...]]:
return normalize_expanded_chunks(self.chunks, self._first.sizes) # pytype: disable=wrong-arg-types # always-use-property-annotation

@cached_property
@functools.cached_property
def offsets(self) -> dict[str, list[int]]:
return _chunks_to_offsets(self.expanded_chunks)

@cached_property
@functools.cached_property
def offset_index(self) -> dict[str, dict[int, int]]:
return compute_offset_index(self.offsets)

@@ -542,7 +512,78 @@ def _task_count(self) -> int:
total += int(np.prod(count_list))
return total

@cached_property
def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]:
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
with inc_timer_msec(self.__class__, "read-msec"):
sizes = {
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
for dim, offset in key.offsets.items()
}
slices = offsets_to_slices(key.offsets, sizes)
results = []
for ds in self._datasets:
dataset = ds if key.vars is None else ds[list(key.vars)]
valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
chunk = dataset.isel(valid_slices)
# Load the data, using a separate thread for each variable
num_threads = len(dataset)
result = chunk.chunk().compute(num_workers=num_threads)
results.append(result)

inc_counter(self.__class__, "read-chunks")
inc_counter(
self.__class__, "read-bytes", sum(result.nbytes for result in results)
)

if isinstance(self.dataset, xarray.Dataset):
return key, results[0]
else:
return key, results
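
For orientation, _key_to_chunks turns a key's integer offsets into isel slices via offsets_to_slices, defined earlier in this module and exported as xarray_beam.offsets_to_slices. A minimal sketch of that mapping (not part of this diff), assuming a single dimension chunked in blocks of 5:

# Hedged sketch: how a Key's offsets become isel slices. Assumes
# offsets_to_slices maps each offset to slice(offset, offset + size),
# matching its use in _key_to_chunks above.
import numpy as np
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({'foo': (('x',), np.arange(10))})
key = xbeam.Key({'x': 5})
slices = xbeam.offsets_to_slices(key.offsets, {'x': 5})
chunk = ds.isel(slices)  # selects foo values 5..9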


@export
class DatasetToChunks(_DatasetToChunksBase):
"""Split one or more xarray.Datasets into keyed chunks."""

def __init__(
self,
dataset: DatasetOrDatasets,
chunks: Mapping[str, int | tuple[int, ...]] | None = None,
split_vars: bool = False,
num_threads: int | None = None,
shard_keys_threshold: int = 200_000,
tasks_per_shard: int = 10_000,
):
"""Initialize DatasetToChunks.

Args:
dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
[xarray.Dataset, ...]) pairs.
chunks: optional chunking scheme. Required if the dataset is *not* already
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variable or to keep all data variables together. This is
recommended if you don't need to perform joint operations on different
dataset variables and individual variable chunks are sufficiently large.
num_threads: optional number of Dataset chunks to load in parallel per
worker. More threads can increase throughput, but also increases memory
usage and makes it harder for Beam runners to shard work. Note that each
variable in a Dataset is already loaded in parallel, so this is most
useful for Datasets with a small number of variables or when using
split_vars=True.
shard_keys_threshold: threshold at which to compute keys on Beam workers,
rather than only on the host process. This is important for scaling
pipelines to millions of tasks.
tasks_per_shard: number of tasks to emit per shard. Only used if the
number of tasks exceeds shard_keys_threshold.
"""
super().__init__(dataset, chunks, split_vars)
self.num_threads = num_threads
self.shard_keys_threshold = shard_keys_threshold
self.tasks_per_shard = tasks_per_shard

@functools.cached_property
def sharded_dim(self) -> str | None:
# We use the simple heuristic of only sharding inputs along the dimension
# with the most chunks.
@@ -552,7 +593,7 @@ def sharded_dim(self) -> str | None:
}
return max(lengths, key=lengths.get) if lengths else None # pytype: disable=bad-return-type

@cached_property
@functools.cached_property
def shard_count(self) -> int | None:
"""Determine the number of times to shard input keys."""
task_count = self._task_count()
@@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
inputs.append((None, name))
return inputs # pytype: disable=bad-return-type # always-use-property-annotation

def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
"""Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
with inc_timer_msec(self.__class__, "read-msec"):
sizes = {
dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
for dim, offset in key.offsets.items()
}
slices = offsets_to_slices(key.offsets, sizes)
results = []
for ds in self._datasets:
dataset = ds if key.vars is None else ds[list(key.vars)]
valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
chunk = dataset.isel(valid_slices)
# Load the data, using a separate thread for each variable
num_threads = len(dataset)
result = chunk.chunk().compute(num_workers=num_threads)
results.append(result)

inc_counter(self.__class__, "read-chunks")
inc_counter(
self.__class__, "read-bytes", sum(result.nbytes for result in results)
)

if isinstance(self.dataset, xarray.Dataset):
yield key, results[0]
else:
yield key, results

def expand(self, pcoll):
if self.shard_count is None:
# Create all keys on the machine launching the Beam pipeline. This is
@@ -652,11 +665,102 @@ def expand(self, pcoll):
| beam.Reshuffle()
)

return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap(
self._key_to_chunks, num_threads=self.num_threads
)
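
As a usage note (not part of this diff), DatasetToChunks is applied at the pipeline root and emits (Key, xarray.Dataset) pairs; a minimal sketch with an in-memory dataset:

# Minimal usage sketch for DatasetToChunks. Each emitted element is a
# (xarray_beam.Key, xarray.Dataset) pair covering one chunk.
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

ds = xarray.Dataset({'air': (('time', 'x'), np.zeros((365, 10)))})

with beam.Pipeline() as p:
  _ = (
      p
      | xbeam.DatasetToChunks(ds, chunks={'time': 100}, split_vars=False)
      | beam.MapTuple(lambda key, chunk: print(key, dict(chunk.sizes)))
  )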


# TODO(shoyer): expose this function as a public API, after switching it to
# generate Key objects using `indices` instead of `offsets`.
class ReadDataset(_DatasetToChunksBase):
"""Read chunks from an xarray.Dataset into a Beam pipeline.

This PTransform is a Beam "splittable DoFn", which means that it may be
dynamically split by Beam runners into smaller chunks for efficient parallel
execution.
"""

def __init__(
self,
dataset: xarray.Dataset,
chunks: Mapping[str, int | tuple[int, ...]] | None = None,
split_vars: bool = False,
):
"""Initialize ReadDatasets.

Args:
dataset: dataset to split into (Key, xarray.Dataset) chunks.
chunks: optional chunking scheme. Required if the dataset is *not* already
chunked. If the dataset *is* already chunked with Dask, `chunks` takes
precedence over the existing chunks.
split_vars: whether to split the dataset into separate records for each
data variable or to keep all data variables together. This is
recommended if you don't need to perform joint operations on different
dataset variables and individual variable chunks are sufficiently large.
"""
super().__init__(dataset, chunks, split_vars)

@functools.cached_property
def _chunk_index_shapes(
self,
) -> list[tuple[str | None, tuple[str, ...], tuple[int, ...]]]:
"""Calculate the shapes of indices for each chunk of the data.

The result here is a list of tuples of the form (name, dims, shape), where
name is the name of the variable (or None if all variables are consolidated),
dims lists the dimensions along which the variable is chunked, and shape
gives the number of chunk indices along each of those dimensions. For example, if the dataset
had a variable `foo` with dimensions `('x', 'y')`, shape (10, 10) with
chunks `{'x': 5, 'y': 2}`, then this function would return a corresponding
list entry `('foo', ('x', 'y'), (2, 5))`.
"""
out = []
if not self.split_vars:
dims = sorted(self.expanded_chunks)
shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
out.append((None, dims, shape))
else:
for name, variable in self._first.items():
dims = tuple(d for d in variable.dims if d in self.expanded_chunks)
shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
out.append((name, dims, shape))
return out # pytype: disable=bad-return-type
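
The docstring example above can be checked directly; a small sketch assuming expanded_chunks normalizes chunks {'x': 5, 'y': 2} over a (10, 10) variable:

# Hedged check of the docstring example: with shape (10, 10) and chunks
# {'x': 5, 'y': 2}, there are 2 chunk indices along x and 5 along y, so the
# entry for 'foo' is ('foo', ('x', 'y'), (2, 5)).
expanded_chunks = {'x': (5, 5), 'y': (2, 2, 2, 2, 2)}
dims = ('x', 'y')
shape = tuple(len(expanded_chunks[d]) for d in dims)
assert shape == (2, 5)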

@functools.cached_property
def _cumulative_sizes(self) -> np.ndarray:
var_sizes = [math.prod(shape) for _, _, shape in self._chunk_index_shapes]
return np.cumsum([0] + var_sizes)

def _index_to_key(self, position: int) -> Key:
assert 0 <= position < self._cumulative_sizes[-1]
var_index = (
np.searchsorted(self._cumulative_sizes, position, side="right") - 1
)
offset = position - self._cumulative_sizes[var_index]
name, dims, shape = self._chunk_index_shapes[var_index]
indices = np.unravel_index(offset, shape)
offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)}
return Key(offsets, vars=None if name is None else {name})

def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]:
return self._key_to_chunks(self._index_to_key(position)) # pytype: disable=bad-return-type

def expand(
self, pbegin: beam.PBegin
) -> beam.PCollection[tuple[Key, xarray.Dataset]]:
element_count = self._task_count()
assert element_count > 0
# For simplicity, assume that all chunks are approximately the same size,
# even if variables are being split and some variables have different
# sizes. This assumption could be relaxed in the future, with an
# improved version of RangeSource.
avg_chunk_bytes = math.ceil(self._first.nbytes / element_count)
source = range_source.RangeSource(
element_count, avg_chunk_bytes, self._get_element
)
return pbegin | beam.io.Read(source)
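
For comparison with DatasetToChunks above, a hedged sketch of applying ReadDataset. It is defined here in xarray_beam._src.core and, per the TODO above, is not yet exported as public API, so the import path below is an interim assumption:

# Hedged sketch for ReadDataset (not yet public API). It expands from
# PBegin via beam.io.Read of a RangeSource, so runners can split the read
# dynamically; flat positions map to keys as in _index_to_key above
# (searchsorted over _cumulative_sizes picks the variable, and
# np.unravel_index recovers the per-dimension chunk indices).
import apache_beam as beam
import numpy as np
import xarray
from xarray_beam._src import core

ds = xarray.Dataset({'air': (('time', 'x'), np.zeros((365, 10)))})

with beam.Pipeline() as p:
  chunks = p | core.ReadDataset(ds, chunks={'time': 100})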


def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
"""Ensure that a dataset contains no chunked variables."""
for var_name, variable in dataset.variables.items():