@@ -227,6 +227,98 @@ def _infer_new_chunks(
227227 return new_chunks
228228
229229
def _normalize_and_validate_chunk(
    template: xarray.Dataset,
    chunks: Mapping[str, int],
    split_vars: bool,
    key: core.Key,
    dataset: xarray.Dataset,
) -> tuple[core.Key, xarray.Dataset]:
  """Validate and normalize a (key, dataset) pair against a Dataset's schema.

  Args:
    template: lazy xarray.Dataset describing the schema of the full dataset.
    chunks: mapping from dimension name to (maximum) chunk size.
    split_vars: whether chunks are split across data variables.
    key: key associated with this chunk. If ``split_vars=True`` and
      ``key.vars`` is unset, it is filled in from the chunk's variables.
    dataset: the dataset chunk itself.

  Returns:
    Tuple of the normalized key and the (unmodified) dataset chunk.

  Raises:
    ValueError: if the (key, dataset) pair is inconsistent with ``template``,
      ``chunks`` or ``split_vars``.
  """

  if split_vars:
    if key.vars is None:
      # Normalize: infer vars from the chunk's data variables.
      key = key.replace(vars=set(dataset.keys()))
    elif key.vars != set(dataset.keys()):
      raise ValueError(
          f'dataset keys {sorted(dataset.keys())} do not match'
          f' key.vars={sorted(key.vars)}'
      )
  elif key.vars is not None:
    raise ValueError(f'must not set vars on key if split_vars=False: {key}')

  # Normalize: explicitly record a zero offset for every chunk dimension that
  # the key does not already mention.
  new_offsets = dict(key.offsets)
  for dim in dataset.dims:
    if dim not in new_offsets:
      new_offsets[dim] = 0
  if len(new_offsets) != len(key.offsets):
    key = key.replace(offsets=new_offsets)

  core._ensure_chunk_is_computed(key, dataset)

  def _with_dataset(msg: str) -> str:
    # Append the key and an indented repr of the chunk to an error message.
    dataset_repr = textwrap.indent(repr(dataset), prefix='  ')
    return f'{msg}\nKey: {key}\nDataset chunk:\n{dataset_repr}'

  def _bad_template_error(msg: str):
    template_repr = textwrap.indent(repr(template), prefix='  ')
    # Fix: include a leading newline so 'Template:' starts on its own line
    # instead of running into the end of the dataset repr.
    raise ValueError(_with_dataset(msg) + f'\nTemplate:\n{template_repr}')

  # Each chunk variable must exist in the template with matching dtype/dims.
  for k, v in dataset.items():
    if k not in template:
      # Fix: single space before the variable list (was a double space).
      _bad_template_error(
          f'Chunk variable {k!r} not found in template variables'
          f' {list(template.data_vars)}:'
      )
    if v.dtype != template[k].dtype:
      _bad_template_error(
          f'Chunk variable {k!r} has dtype {v.dtype} which does not match'
          f' template variable dtype {template[k].dtype}:'
      )
    if v.dims != template[k].dims:
      _bad_template_error(
          f'Chunk variable {k!r} has dims {v.dims} which does not match'
          f' template variable dims {template[k].dims}:'
      )

  # Each chunk dimension must be aligned with the chunking scheme: offsets are
  # multiples of the chunk size, and every chunk is full-sized except possibly
  # the last one along each dimension.
  for dim, size in dataset.sizes.items():
    if dim not in chunks:
      raise ValueError(
          _with_dataset(
              f'Dataset dimension {dim!r} not found in chunks {chunks}:'
          )
      )
    offset = key.offsets[dim]
    if offset % chunks[dim] != 0:
      raise ValueError(
          _with_dataset(
              f'Chunk offset {offset} is not aligned with chunk '
              f'size {chunks[dim]} for dimension {dim!r}:'
          )
      )
    if offset + size > template.sizes[dim]:
      _bad_template_error(
          f'Chunk dimension {dim!r} has size {size} which is larger than the '
          f'remaining size {template.sizes[dim] - offset} in the '
          'template:'
      )
    is_last_chunk = offset + chunks[dim] > template.sizes[dim]
    if is_last_chunk:
      expected_size = template.sizes[dim] - offset
      if size != expected_size:
        _bad_template_error(
            f'Chunk dimension {dim!r} is the last chunk, but has size {size} '
            f'which does not match expected size {expected_size}:'
        )
    elif size != chunks[dim]:
      _bad_template_error(
          f'Chunk dimension {dim!r} has size {size} which does not match'
          f' chunk size {chunks[dim]}:'
      )

  return key, dataset
320+
321+
230322def _apply_to_each_chunk (
231323 func : Callable [[xarray .Dataset ], xarray .Dataset ],
232324 old_chunks : Mapping [str , int ],
@@ -302,9 +394,8 @@ def __init__(
302394 ):
303395 """Low level interface for creating a new Dataset, without validation.
304396
305- Most users should use the higher level
306- :py:class:`xarray_beam.Dataset.from_xarray` or
307- :py:class:`xarray_beam.Dataset.from_zarr` instead.
Unless you're really sure you don't need validation, prefer using
:py:meth:`xarray_beam.Dataset.from_ptransform`.
308399
309400 Args:
310401 template: xarray.Dataset describing the structure of this dataset,
@@ -317,9 +408,7 @@ def __init__(
317408 this dataset's data.
318409 """
319410 self ._template = template
320- self ._chunks = {
321- k : min (template .sizes [k ], v ) for k , v in chunks .items ()
322- }
411+ self ._chunks = {k : min (template .sizes [k ], v ) for k , v in chunks .items ()}
323412 self ._split_vars = split_vars
324413 self ._ptransform = ptransform
325414
@@ -390,6 +479,62 @@ def __repr__(self):
390479 + textwrap .indent ('\n ' .join (base .split ('\n ' )[1 :]), ' ' * 4 )
391480 )
392481
482+ @classmethod
483+ def from_ptransform (
484+ cls ,
485+ ptransform : beam .PTransform ,
486+ * ,
487+ template : xarray .Dataset ,
488+ chunks : Mapping [str | types .EllipsisType , int ],
489+ split_vars : bool = False ,
490+ ) -> Dataset :
491+ """Create an xarray_beam.Dataset from a Beam PTransform.
492+
493+ This is an advanced constructor that allows you to create an
494+ ``xarray_beam.Dataset`` from an existing Beam PTransform that produces
495+ ``(Key, xarray.Dataset)`` pairs.
496+
497+ The PTransform should produce chunks that conform to the given ``template``,
498+ ``chunks``, and ``split_vars`` arguments. This constructor will add a
499+ validation step to the PTransform to normalize keys into the strictest
500+ possible form based on the other arguments, and ensure that transform
501+ outputs are valid.
502+
503+ Args:
504+ ptransform: A Beam PTransform that yields ``(Key, xarray.Dataset)`` pairs.
505+ You only need to set ``offsets`` on these keys, ``vars`` will be
506+ automatically set based on the dataset if ``split_vars`` is True.
507+ template: An ``xarray.Dataset`` object representing the schema
508+ (coordinates, dimensions, data variables, and attributes) of the full
509+ dataset, as produced by :py:func:`xarray_beam.make_template`, with data
510+ variables backed by Dask arrays.
511+ chunks: A dictionary mapping dimension names to integer chunk sizes. Every
512+ chunk produced by ``ptransform`` must have dimensions of these sizes,
513+ except for the last chunk in each dimension, which may be smaller.
514+ split_vars: A boolean indicating whether the chunks in ``ptransform`` are
515+ split across variables, or if each chunk contains all variables.
516+
517+ Returns:
518+ An ``xarray_beam.Dataset`` instance wrapping the PTransform.
519+ """
520+ if not isinstance (chunks , Mapping ):
521+ raise TypeError (
522+ f'chunks must be a mapping for from_ptransform, got { chunks } '
523+ )
524+ for v in chunks .values ():
525+ if not isinstance (v , int ):
526+ raise TypeError (
527+ 'chunks must be a mapping with integer values for from_ptransform,'
528+ f' got { chunks } '
529+ )
530+ chunks = normalize_chunks (chunks , template )
531+ ptransform = ptransform | _get_label ("validate" ) >> beam .MapTuple (
532+ functools .partial (
533+ _normalize_and_validate_chunk , template , chunks , split_vars
534+ )
535+ )
536+ return cls (template , chunks , split_vars , ptransform )
537+
393538 @classmethod
394539 def from_xarray (
395540 cls ,
0 commit comments