 import immutabledict
 import numpy as np
 import xarray
+from xarray_beam._src import range_source
 from xarray_beam._src import threadmap


-T = TypeVar('T')
+T = TypeVar("T")


 def export(obj: T) -> T:
-  obj.__module__ = 'xarray_beam'
+  obj.__module__ = "xarray_beam"
   return obj

@@ -122,7 +123,6 @@ class Key:
   Key(indices={'x': 4}, vars={'bar'})
   >>> key.with_indices(x=5)
   Key(indices={'x': 5}, vars={'bar'})
-
   """

   # pylint: disable=redefined-builtin
@@ -184,8 +184,8 @@ def with_indices(self, **indices: int | None) -> Key:
     """Replace some indices with new values.

     Args:
-      **indices: indices to override (for integer values) or remove, with
-        values of ``None``.
+      **indices: indices to override (for integer values) or remove, with values
+        of ``None``.

     Returns:
       New Key with the specified indices.
@@ -421,49 +421,19 @@ def normalize_expanded_chunks(
 )


-@export
-class DatasetToChunks(beam.PTransform, Generic[DatasetOrDatasets]):
-  """Split one or more xarray.Datasets into keyed chunks."""
+class _DatasetToChunksBase(beam.PTransform, Generic[DatasetOrDatasets]):
+  """Base class for PTransforms that split Datasets into chunks."""

   def __init__(
       self,
       dataset: DatasetOrDatasets,
       chunks: Mapping[str, int | tuple[int, ...]] | None = None,
       split_vars: bool = False,
-      num_threads: int | None = None,
-      shard_keys_threshold: int = 200_000,
-      tasks_per_shard: int = 10_000,
   ):
-    """Initialize DatasetToChunks.
-
-    Args:
-      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
-        [xarray.Dataset, ...]) pairs.
-      chunks: optional chunking scheme. Required if the dataset is *not* already
-        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
-        precedence over the existing chunks.
-      split_vars: whether to split the dataset into separate records for each
-        data variable or to keep all data variables together. This is
-        recommended if you don't need to perform joint operations on different
-        dataset variables and individual variable chunks are sufficiently large.
-      num_threads: optional number of Dataset chunks to load in parallel per
-        worker. More threads can increase throughput, but also increases memory
-        usage and makes it harder for Beam runners to shard work. Note that each
-        variable in a Dataset is already loaded in parallel, so this is most
-        useful for Datasets with a small number of variables or when using
-        split_vars=True.
-      shard_keys_threshold: threshold at which to compute keys on Beam workers,
-        rather than only on the host process. This is important for scaling
-        pipelines to millions of tasks.
-      tasks_per_shard: number of tasks to emit per shard. Only used if the
-        number of tasks exceeds shard_keys_threshold.
-    """
+    """Initialize _DatasetToChunksBase."""
     self.dataset = dataset
     self._validate(dataset, split_vars)
     self.split_vars = split_vars
-    self.num_threads = num_threads
-    self.shard_keys_threshold = shard_keys_threshold
-    self.tasks_per_shard = tasks_per_shard

     if chunks is None:
       dask_chunks = self._first.chunks
@@ -542,6 +512,77 @@ def _task_count(self) -> int:
       total += int(np.prod(count_list))
     return total

+  def _key_to_chunks(self, key: Key) -> tuple[Key, DatasetOrDatasets]:
+    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
+    with inc_timer_msec(self.__class__, "read-msec"):
+      sizes = {
+          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
+          for dim, offset in key.offsets.items()
+      }
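+      # For example (illustrative values): offsets {'x': 10} with sizes
+      # {'x': 5} correspond to slices {'x': slice(10, 15)}.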
+      slices = offsets_to_slices(key.offsets, sizes)
+      results = []
+      for ds in self._datasets:
+        dataset = ds if key.vars is None else ds[list(key.vars)]
+        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
+        chunk = dataset.isel(valid_slices)
+        # Load the data, using a separate thread for each variable.
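+        # (Calling .chunk() wraps each variable in Dask, so .compute() with
+        # num_workers can load the variables concurrently.)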
+        num_threads = len(dataset)
+        result = chunk.chunk().compute(num_workers=num_threads)
+        results.append(result)
+
+      inc_counter(self.__class__, "read-chunks")
+      inc_counter(
+          self.__class__, "read-bytes", sum(result.nbytes for result in results)
+      )
+
+    if isinstance(self.dataset, xarray.Dataset):
+      return key, results[0]
+    else:
+      return key, results
+
+
+@export
+class DatasetToChunks(_DatasetToChunksBase):
+  """Split one or more xarray.Datasets into keyed chunks."""
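+  # Sketch of typical usage (`pipeline`, `ds` and `f` are hypothetical, not
+  # part of this change):
+  #   pipeline | DatasetToChunks(ds, chunks={'time': 100}) | beam.MapTuple(f)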
+
+  def __init__(
+      self,
+      dataset: DatasetOrDatasets,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+      num_threads: int | None = None,
+      shard_keys_threshold: int = 200_000,
+      tasks_per_shard: int = 10_000,
+  ):
+    """Initialize DatasetToChunks.
+
+    Args:
+      dataset: dataset or datasets to split into (Key, xarray.Dataset) or (Key,
+        [xarray.Dataset, ...]) pairs.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+      num_threads: optional number of Dataset chunks to load in parallel per
+        worker. More threads can increase throughput, but also increase memory
+        usage and make it harder for Beam runners to shard work. Note that each
+        variable in a Dataset is already loaded in parallel, so this is most
+        useful for Datasets with a small number of variables or when using
+        split_vars=True.
+      shard_keys_threshold: threshold at which to compute keys on Beam workers,
+        rather than only on the host process. This is important for scaling
+        pipelines to millions of tasks.
+      tasks_per_shard: number of tasks to emit per shard. Only used if the
+        number of tasks exceeds shard_keys_threshold.
+    """
+    super().__init__(dataset, chunks, split_vars)
+    self.num_threads = num_threads
+    self.shard_keys_threshold = shard_keys_threshold
+    self.tasks_per_shard = tasks_per_shard
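+    # For example (hypothetical numbers): 10**6 total tasks with
+    # tasks_per_shard=10_000 would generate keys in ~100 shards on Beam
+    # workers rather than on the launching host.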
+
   @cached_property
   def sharded_dim(self) -> str | None:
     # We use the simple heuristic of only sharding inputs along the dimension
@@ -610,34 +651,6 @@ def _shard_inputs(self) -> list[tuple[int | None, str | None]]:
       inputs.append((None, name))
     return inputs  # pytype: disable=bad-return-type  # always-use-property-annotation

-  def _key_to_chunks(self, key: Key) -> Iterator[tuple[Key, DatasetOrDatasets]]:
-    """Convert a Key into an in-memory (Key, xarray.Dataset) pair."""
-    with inc_timer_msec(self.__class__, "read-msec"):
-      sizes = {
-          dim: self.expanded_chunks[dim][self.offset_index[dim][offset]]
-          for dim, offset in key.offsets.items()
-      }
-      slices = offsets_to_slices(key.offsets, sizes)
-      results = []
-      for ds in self._datasets:
-        dataset = ds if key.vars is None else ds[list(key.vars)]
-        valid_slices = {k: v for k, v in slices.items() if k in dataset.dims}
-        chunk = dataset.isel(valid_slices)
-        # Load the data, using a separate thread for each variable
-        num_threads = len(dataset)
-        result = chunk.chunk().compute(num_workers=num_threads)
-        results.append(result)
-
-      inc_counter(self.__class__, "read-chunks")
-      inc_counter(
-          self.__class__, "read-bytes", sum(result.nbytes for result in results)
-      )
-
-    if isinstance(self.dataset, xarray.Dataset):
-      yield key, results[0]
-    else:
-      yield key, results
-
   def expand(self, pcoll):
     if self.shard_count is None:
       # Create all keys on the machine launching the Beam pipeline. This is
@@ -652,11 +665,95 @@ def expand(self, pcoll):
           | beam.Reshuffle()
       )

-    return key_pcoll | "KeyToChunks" >> threadmap.FlatThreadMap(
+    return key_pcoll | "KeyToChunks" >> threadmap.ThreadMap(
         self._key_to_chunks, num_threads=self.num_threads
     )


+# TODO(shoyer): expose this class as a public API, after switching it to
+# generate Key objects using `indices` instead of `offsets`.
+class ReadDataset(_DatasetToChunksBase):
+  """Read chunks from an xarray.Dataset into a Beam pipeline.
+
+  This PTransform is a Beam "splittable DoFn", which means that it may be
+  dynamically split by Beam runners into smaller chunks for efficient parallel
+  execution.
+  """
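+  # Hypothetical usage, once exposed as a public API (see the TODO above):
+  #   pipeline | ReadDataset(ds, chunks={'time': 100})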
+
+  def __init__(
+      self,
+      dataset: xarray.Dataset,
+      chunks: Mapping[str, int | tuple[int, ...]] | None = None,
+      split_vars: bool = False,
+  ):
+    """Initialize ReadDataset.
+
+    Args:
+      dataset: dataset to split into (Key, xarray.Dataset) chunks.
+      chunks: optional chunking scheme. Required if the dataset is *not* already
+        chunked. If the dataset *is* already chunked with Dask, `chunks` takes
+        precedence over the existing chunks.
+      split_vars: whether to split the dataset into separate records for each
+        data variable or to keep all data variables together. This is
+        recommended if you don't need to perform joint operations on different
+        dataset variables and individual variable chunks are sufficiently large.
+    """
+    super().__init__(dataset, chunks, split_vars)
+
+  @cached_property
+  def _var_chunk_counts(
+      self,
+  ) -> list[tuple[str | None, list[str], tuple[int, ...]]]:
+    out = []
+    if not self.split_vars:
+      dims = sorted(self.expanded_chunks)
+      shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+      out.append((None, dims, shape))
+    else:
+      for name, variable in self._first.items():
+        dims = sorted([d for d in variable.dims if d in self.expanded_chunks])
+        shape = tuple(len(self.expanded_chunks[dim]) for dim in dims)
+        out.append((name, dims, shape))
+    return out  # pytype: disable=bad-return-type
+
+  @cached_property
+  def _var_sizes(self) -> list[int]:
+    return [int(np.prod(shape)) for _, _, shape in self._var_chunk_counts]
+
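+  # _cumulative_sizes gives each variable a contiguous range of flat
+  # positions; e.g. (illustrative values) _var_sizes [6, 4] yields [0, 6, 10],
+  # so positions 0-5 map to the first variable and 6-9 to the second.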
+  @cached_property
+  def _cumulative_sizes(self) -> np.ndarray:
+    return np.cumsum([0] + self._var_sizes)
+
+  def _index_to_key(self, position: int) -> Key:
+    assert 0 <= position < self._cumulative_sizes[-1]
+    var_index = (
+        np.searchsorted(self._cumulative_sizes, position, side="right") - 1
+    )
+    offset = position - self._cumulative_sizes[var_index]
+    name, dims, shape = self._var_chunk_counts[var_index]
+    indices = np.unravel_index(offset, shape)
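+    # e.g. np.unravel_index(5, (2, 3)) == (1, 2): flat chunk 5 of a 2x3 grid.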
+    offsets = {dim: self.offsets[dim][idx] for dim, idx in zip(dims, indices)}
+    return Key(offsets, vars=None if name is None else {name})
+
+  def _get_element(self, position: int) -> tuple[Key, xarray.Dataset]:
+    return self._key_to_chunks(self._index_to_key(position))  # pytype: disable=bad-return-type
+
+  def expand(
+      self, pbegin: beam.PBegin
+  ) -> beam.PCollection[tuple[Key, xarray.Dataset]]:
+    element_count = self._task_count()
+    assert element_count > 0
+    # For simplicity, assume that all chunks are approximately the same size,
+    # even if variables are being split and some variables have different
+    # sizes. This assumption could be relaxed in the future, with an improved
+    # version of RangeSource.
+    avg_chunk_bytes = math.ceil(self._first.nbytes / element_count)
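+    # E.g. (hypothetical numbers) a 1 GB dataset split into 100 chunks gives
+    # ~10 MB per element, presumably used as a size hint when runners split
+    # the source.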
+    source = range_source.RangeSource(
+        element_count, avg_chunk_bytes, self._get_element
+    )
+    return pbegin | beam.io.Read(source)
+
+
 def _ensure_chunk_is_computed(key: Key, dataset: xarray.Dataset) -> None:
   """Ensure that a dataset contains no chunked variables."""
   for var_name, variable in dataset.variables.items():