@@ -55,7 +55,7 @@ def dask_array_ones(*args):
     "nansum",
     "argmax",
     "nanfirst",
-    pytest.param("nanargmax", marks=(pytest.mark.skip,)),
+    "nanargmax",
     "prod",
     "nanprod",
     "mean",
@@ -69,7 +69,7 @@ def dask_array_ones(*args):
     "min",
     "nanmin",
     "argmin",
-    pytest.param("nanargmin", marks=(pytest.mark.skip,)),
+    "nanargmin",
     "any",
     "all",
     "nanlast",
@@ -233,8 +233,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
     # computing silences a bunch of dask warnings
     array_ = array.compute() if chunks is not None else array
     if "arg" in func and add_nan_by:
+        # NaNs are in by, but we can't call np.argmax([..., NaN, ...]):
+        # that would return the index of the NaN.
+        # Instead, we insert NaNs where there are NaNs in by, and
+        # call np.nanargmax.
+        func_ = f"nan{func}" if "nan" not in func else func
         array_[..., nanmask] = np.nan
-        expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
+        expected = getattr(np, func_)(array_, axis=-1, **kwargs)
     # elif func in ["first", "last"]:
     #     expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
     elif func in ["nanfirst", "nanlast"]:
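For reference (not part of the diff), the NumPy behaviour the new comments describe can be checked directly: np.argmax reports the position of a NaN because NaN propagates as the maximum, while np.nanargmax ignores NaNs.

```python
import numpy as np

arr = np.array([1.0, np.nan, 3.0])
print(np.argmax(arr))     # 1: index of the NaN
print(np.nanargmax(arr))  # 2: index of the largest non-NaN value
```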
@@ -259,6 +264,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
 
     params = list(itertools.product(["map-reduce"], [True, False, None]))
     params.extend(itertools.product(["cohorts"], [False, None]))
+    if chunks == -1:
+        params.extend([("blockwise", None)])
+
     for method, reindex in params:
         call = partial(
             groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
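A standalone sketch (not part of the diff) of the (method, reindex) matrix this builds once the blockwise entry is appended for single-chunk arrays (chunks == -1):

```python
import itertools

params = list(itertools.product(["map-reduce"], [True, False, None]))
params.extend(itertools.product(["cohorts"], [False, None]))
params.extend([("blockwise", None)])  # appended only when chunks == -1 in the test
print(params)
# [('map-reduce', True), ('map-reduce', False), ('map-reduce', None),
#  ('cohorts', False), ('cohorts', None), ('blockwise', None)]
```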
@@ -269,11 +277,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 call()
             continue
         actual, *groups = call()
-        if "arg" not in func:
-            # make sure we use simple combine
-            assert any("simple-combine" in key for key in actual.dask.layers.keys())
-        else:
-            assert any("grouped-combine" in key for key in actual.dask.layers.keys())
+        if method != "blockwise":
+            if "arg" not in func:
+                # make sure we use simple combine
+                assert any("simple-combine" in key for key in actual.dask.layers.keys())
+            else:
+                assert any("grouped-combine" in key for key in actual.dask.layers.keys())
         for actual_group, expect in zip(groups, expected_groups):
             assert_equal(actual_group, expect, tolerance)
         if "arg" in func: