1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING , Hashable , Iterable , Sequence
3
+ from typing import TYPE_CHECKING , Any , Hashable , Iterable , Sequence , Union
4
4
5
5
import numpy as np
6
6
import pandas as pd
19
19
from .xrutils import _contains_cftime_datetimes , _to_pytimedelta , datetime_to_numeric
20
20
21
21
if TYPE_CHECKING :
22
- from xarray import DataArray , Dataset , Resample
22
+ from xarray .core .resample import Resample
23
+ from xarray .core .types import T_DataArray , T_Dataset
24
+
25
+ Dims = Union [str , Iterable [Hashable ], None ]
23
26
24
27
25
28
def _get_input_core_dims (group_names , dim , ds , grouper_dims ):
@@ -51,13 +54,13 @@ def lookup_order(dimension):
51
54
52
55
53
56
def xarray_reduce (
54
- obj : Dataset | DataArray ,
55
- * by : DataArray | Iterable [ str ] | Iterable [ DataArray ] ,
57
+ obj : T_Dataset | T_DataArray ,
58
+ * by : T_DataArray | Hashable ,
56
59
func : str | Aggregation ,
57
60
expected_groups = None ,
58
61
isbin : bool | Sequence [bool ] = False ,
59
62
sort : bool = True ,
60
- dim : Hashable = None ,
63
+ dim : Dims | ellipsis = None ,
61
64
split_out : int = 1 ,
62
65
fill_value = None ,
63
66
method : str = "map-reduce" ,
@@ -203,8 +206,11 @@ def xarray_reduce(
203
206
if keep_attrs is None :
204
207
keep_attrs = True
205
208
206
- if isinstance (isbin , bool ):
207
- isbin = (isbin ,) * nby
209
+ if isinstance (isbin , Sequence ):
210
+ isbins = isbin
211
+ else :
212
+ isbins = (isbin ,) * nby
213
+
208
214
if expected_groups is None :
209
215
expected_groups = (None ,) * nby
210
216
if isinstance (expected_groups , (np .ndarray , list )): # TODO: test for list
@@ -217,78 +223,86 @@ def xarray_reduce(
217
223
raise NotImplementedError
218
224
219
225
# eventually drop the variables we are grouping by
220
- maybe_drop = [b for b in by if isinstance (b , str )]
226
+ maybe_drop = [b for b in by if isinstance (b , Hashable )]
221
227
unindexed_dims = tuple (
222
228
b
223
- for b , isbin_ in zip (by , isbin )
224
- if isinstance (b , str ) and not isbin_ and b in obj .dims and b not in obj .indexes
229
+ for b , isbin_ in zip (by , isbins )
230
+ if isinstance (b , Hashable ) and not isbin_ and b in obj .dims and b not in obj .indexes
225
231
)
226
232
227
- by : tuple [ DataArray ] = tuple (obj [g ] if isinstance (g , str ) else g for g in by ) # type: ignore
233
+ by_da = tuple (obj [g ] if isinstance (g , Hashable ) else g for g in by )
228
234
229
235
grouper_dims = []
230
- for g in by :
236
+ for g in by_da :
231
237
for d in g .dims :
232
238
if d not in grouper_dims :
233
239
grouper_dims .append (d )
234
240
235
- if isinstance (obj , xr .DataArray ):
236
- ds = obj ._to_temp_dataset ()
237
- else :
241
+ if isinstance (obj , xr .Dataset ):
238
242
ds = obj
243
+ else :
244
+ ds = obj ._to_temp_dataset ()
239
245
240
246
ds = ds .drop_vars ([var for var in maybe_drop if var in ds .variables ])
241
247
242
248
if dim is Ellipsis :
243
249
if nby > 1 :
244
250
raise NotImplementedError ("Multiple by are not allowed when dim is Ellipsis." )
245
- dim = tuple (obj .dims )
246
- if by [0 ].name in ds .dims and not isbin [0 ]:
247
- dim = tuple (d for d in dim if d != by [0 ].name )
251
+ name_ = by_da [0 ].name
252
+ if name_ in ds .dims and not isbins [0 ]:
253
+ dim_tuple = tuple (d for d in obj .dims if d != name_ )
254
+ else :
255
+ dim_tuple = tuple (obj .dims )
248
256
elif dim is not None :
249
- dim = _atleast_1d (dim )
257
+ dim_tuple = _atleast_1d (dim )
250
258
else :
251
- dim = tuple ()
259
+ dim_tuple = tuple ()
252
260
253
261
# broadcast all variables against each other along all dimensions in `by` variables
254
262
# don't exclude `dim` because it need not be a dimension in any of the `by` variables!
255
263
# in the case where dim is Ellipsis, and by.ndim < obj.ndim
256
264
# then we also broadcast `by` to all `obj.dims`
257
265
# TODO: avoid this broadcasting
258
- exclude_dims = tuple (d for d in ds .dims if d not in grouper_dims and d not in dim )
259
- ds , * by = xr .broadcast (ds , * by , exclude = exclude_dims )
266
+ exclude_dims = tuple (d for d in ds .dims if d not in grouper_dims and d not in dim_tuple )
267
+ ds_broad , * by_broad = xr .broadcast (ds , * by_da , exclude = exclude_dims )
260
268
261
- if not dim :
262
- dim = tuple (by [0 ].dims )
269
+ # all members of by_broad have the same dimensions
270
+ # so we just pull by_broad[0].dims if dim is None
271
+ if not dim_tuple :
272
+ dim_tuple = tuple (by_broad [0 ].dims )
263
273
264
- if any (d not in grouper_dims and d not in obj .dims for d in dim ):
274
+ if any (d not in grouper_dims and d not in obj .dims for d in dim_tuple ):
265
275
raise ValueError (f"Cannot reduce over absent dimensions { dim } ." )
266
276
267
- dims_not_in_groupers = tuple (d for d in dim if d not in grouper_dims )
268
- if dims_not_in_groupers == tuple (dim ) and not any (isbin ):
277
+ dims_not_in_groupers = tuple (d for d in dim_tuple if d not in grouper_dims )
278
+ if dims_not_in_groupers == tuple (dim_tuple ) and not any (isbins ):
269
279
# reducing along a dimension along which groups do not vary
270
280
# This is really just a normal reduction.
271
281
# This is not right when binning so we exclude.
272
- if skipna and isinstance (func , str ):
273
- dsfunc = func [3 :]
282
+ if isinstance (func , str ):
283
+ dsfunc = func [3 :] if skipna else func
274
284
else :
275
- dsfunc = func
285
+ raise NotImplementedError (
286
+ "func must be a string when reducing along a dimension not present in `by`"
287
+ )
276
288
# TODO: skipna needs test
277
- result = getattr (ds , dsfunc )(dim = dim , skipna = skipna )
289
+ result = getattr (ds_broad , dsfunc )(dim = dim_tuple , skipna = skipna )
278
290
if isinstance (obj , xr .DataArray ):
279
291
return obj ._from_temp_dataset (result )
280
292
else :
281
293
return result
282
294
283
- axis = tuple (range (- len (dim ), 0 ))
284
- group_names = tuple (g .name if not binned else f"{ g .name } _bins" for g , binned in zip (by , isbin ))
285
-
286
- group_shape = [None ] * len (by )
287
- expected_groups = list (expected_groups )
295
+ axis = tuple (range (- len (dim_tuple ), 0 ))
288
296
289
297
# Set expected_groups and convert to index since we need coords, sizes
290
298
# for output xarray objects
291
- for idx , (b , expect , isbin_ ) in enumerate (zip (by , expected_groups , isbin )):
299
+ expected_groups = list (expected_groups )
300
+ group_names : tuple [Any , ...] = ()
301
+ group_sizes : dict [Any , int ] = {}
302
+ for idx , (b_ , expect , isbin_ ) in enumerate (zip (by_broad , expected_groups , isbins )):
303
+ group_name = b_ .name if not isbin_ else f"{ b_ .name } _bins"
304
+ group_names += (group_name ,)
305
+
292
306
if isbin_ and isinstance (expect , int ):
293
307
raise NotImplementedError (
294
308
"flox does not support binning into an integer number of bins yet."
@@ -297,13 +311,21 @@ def xarray_reduce(
297
311
if isbin_ :
298
312
raise ValueError (
299
313
f"Please provided bin edges for group variable { idx } "
300
- f"named { group_names [ idx ] } in expected_groups."
314
+ f"named { group_name } in expected_groups."
301
315
)
302
- expected_groups [idx ] = _get_expected_groups (b .data , sort = sort , raise_if_dask = True )
303
-
304
- expected_groups = _convert_expected_groups_to_index (expected_groups , isbin , sort = sort )
305
- group_shape = tuple (len (e ) for e in expected_groups )
306
- group_sizes = dict (zip (group_names , group_shape ))
316
+ expect_ = _get_expected_groups (b_ .data , sort = sort , raise_if_dask = True )
317
+ else :
318
+ expect_ = expect
319
+ expect_index = _convert_expected_groups_to_index ((expect_ ,), (isbin_ ,), sort = sort )[0 ]
320
+
321
+ # The if-check is for type hinting mainly, it narrows down the return
322
+ # type of _convert_expected_groups_to_index to pure pd.Index:
323
+ if expect_index is not None :
324
+ expected_groups [idx ] = expect_index
325
+ group_sizes [group_name ] = len (expect_index )
326
+ else :
327
+ # This will never be reached
328
+ raise ValueError ("expect_index cannot be None" )
307
329
308
330
def wrapper (array , * by , func , skipna , ** kwargs ):
309
331
# Handle skipna here because I need to know dtype to make a good default choice.
@@ -349,20 +371,20 @@ def wrapper(array, *by, func, skipna, **kwargs):
349
371
if isinstance (obj , xr .Dataset ):
350
372
# broadcasting means the group dim gets added to ds, so we check the original obj
351
373
for k , v in obj .data_vars .items ():
352
- is_missing_dim = not (any (d in v .dims for d in dim ))
374
+ is_missing_dim = not (any (d in v .dims for d in dim_tuple ))
353
375
if is_missing_dim :
354
376
missing_dim [k ] = v
355
377
356
- input_core_dims = _get_input_core_dims (group_names , dim , ds , grouper_dims )
378
+ input_core_dims = _get_input_core_dims (group_names , dim_tuple , ds_broad , grouper_dims )
357
379
input_core_dims += [input_core_dims [- 1 ]] * (nby - 1 )
358
380
359
381
actual = xr .apply_ufunc (
360
382
wrapper ,
361
- ds .drop_vars (tuple (missing_dim )).transpose (..., * grouper_dims ),
362
- * by ,
383
+ ds_broad .drop_vars (tuple (missing_dim )).transpose (..., * grouper_dims ),
384
+ * by_broad ,
363
385
input_core_dims = input_core_dims ,
364
386
# for xarray's test_groupby_duplicate_coordinate_labels
365
- exclude_dims = set (dim ),
387
+ exclude_dims = set (dim_tuple ),
366
388
output_core_dims = [group_names ],
367
389
dask = "allowed" ,
368
390
dask_gufunc_kwargs = dict (output_sizes = group_sizes ),
@@ -379,27 +401,27 @@ def wrapper(array, *by, func, skipna, **kwargs):
379
401
"engine" : engine ,
380
402
"reindex" : reindex ,
381
403
"expected_groups" : tuple (expected_groups ),
382
- "isbin" : isbin ,
404
+ "isbin" : isbins ,
383
405
"finalize_kwargs" : finalize_kwargs ,
384
406
},
385
407
)
386
408
387
409
# restore non-dim coord variables without the core dimension
388
410
# TODO: shouldn't apply_ufunc handle this?
389
- for var in set (ds .variables ) - set (ds .dims ):
390
- if all (d not in ds [var ].dims for d in dim ):
391
- actual [var ] = ds [var ]
411
+ for var in set (ds_broad .variables ) - set (ds_broad .dims ):
412
+ if all (d not in ds_broad [var ].dims for d in dim_tuple ):
413
+ actual [var ] = ds_broad [var ]
392
414
393
- for name , expect , by_ in zip (group_names , expected_groups , by ):
415
+ for name , expect , by_ in zip (group_names , expected_groups , by_broad ):
394
416
# Can't remove this till xarray handles IntervalIndex
395
417
if isinstance (expect , pd .IntervalIndex ):
396
418
expect = expect .to_numpy ()
397
419
if isinstance (actual , xr .Dataset ) and name in actual :
398
420
actual = actual .drop_vars (name )
399
421
# When grouping by MultiIndex, expect is a pd.Index wrapping
400
422
# an object array of tuples
401
- if name in ds .indexes and isinstance (ds .indexes [name ], pd .MultiIndex ):
402
- levelnames = ds .indexes [name ].names
423
+ if name in ds_broad .indexes and isinstance (ds_broad .indexes [name ], pd .MultiIndex ):
424
+ levelnames = ds_broad .indexes [name ].names
403
425
expect = pd .MultiIndex .from_tuples (expect .values , names = levelnames )
404
426
actual [name ] = expect
405
427
if Version (xr .__version__ ) > Version ("2022.03.0" ):
@@ -414,18 +436,17 @@ def wrapper(array, *by, func, skipna, **kwargs):
414
436
415
437
if nby == 1 :
416
438
for var in actual :
417
- if isinstance (obj , xr .DataArray ):
418
- template = obj
419
- else :
439
+ if isinstance (obj , xr .Dataset ):
420
440
template = obj [var ]
441
+ else :
442
+ template = obj
443
+
421
444
if actual [var ].ndim > 1 :
422
- actual [var ] = _restore_dim_order (actual [var ], template , by [0 ])
445
+ actual [var ] = _restore_dim_order (actual [var ], template , by_broad [0 ])
423
446
424
447
if missing_dim :
425
448
for k , v in missing_dim .items ():
426
- missing_group_dims = {
427
- dim : size for dim , size in group_sizes .items () if dim not in v .dims
428
- }
449
+ missing_group_dims = {d : size for d , size in group_sizes .items () if d not in v .dims }
429
450
# The expand_dims is for backward compat with xarray's questionable behaviour
430
451
if missing_group_dims :
431
452
actual [k ] = v .expand_dims (missing_group_dims ).variable
@@ -439,9 +460,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
439
460
440
461
441
462
def rechunk_for_cohorts (
442
- obj : DataArray | Dataset ,
463
+ obj : T_DataArray | T_Dataset ,
443
464
dim : str ,
444
- labels : DataArray ,
465
+ labels : T_DataArray ,
445
466
force_new_chunk_at ,
446
467
chunksize : int | None = None ,
447
468
ignore_old_chunks : bool = False ,
@@ -486,7 +507,7 @@ def rechunk_for_cohorts(
486
507
)
487
508
488
509
489
- def rechunk_for_blockwise (obj : DataArray | Dataset , dim : str , labels : DataArray ):
510
+ def rechunk_for_blockwise (obj : T_DataArray | T_Dataset , dim : str , labels : T_DataArray ):
490
511
"""
491
512
Rechunks array so that group boundaries line up with chunk boundaries, allowing
492
513
embarrassingly parallel group reductions.
0 commit comments