3
3
import itertools
4
4
import warnings
5
5
from functools import partial , reduce
6
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING , Callable
7
7
8
8
import numpy as np
9
9
import pandas as pd
10
10
import pytest
11
11
from numpy_groupies .aggregate_numpy import aggregate
12
12
13
+ from flox import xrutils
13
14
from flox .aggregations import Aggregation
14
15
from flox .core import (
15
16
_convert_expected_groups_to_index ,
@@ -53,6 +54,7 @@ def dask_array_ones(*args):
53
54
"sum" ,
54
55
"nansum" ,
55
56
"argmax" ,
57
+ "nanfirst" ,
56
58
pytest .param ("nanargmax" , marks = (pytest .mark .skip ,)),
57
59
"prod" ,
58
60
"nanprod" ,
@@ -70,6 +72,7 @@ def dask_array_ones(*args):
70
72
pytest .param ("nanargmin" , marks = (pytest .mark .skip ,)),
71
73
"any" ,
72
74
"all" ,
75
+ "nanlast" ,
73
76
pytest .param ("median" , marks = (pytest .mark .skip ,)),
74
77
pytest .param ("nanmedian" , marks = (pytest .mark .skip ,)),
75
78
)
@@ -78,6 +81,21 @@ def dask_array_ones(*args):
78
81
from flox .core import T_Engine , T_ExpectedGroupsOpt , T_Func2
79
82
80
83
84
+ def _get_array_func (func : str ) -> Callable :
85
+ if func == "count" :
86
+
87
+ def npfunc (x ):
88
+ x = np .asarray (x )
89
+ return (~ np .isnan (x )).sum ()
90
+
91
+ elif func in ["nanfirst" , "nanlast" ]:
92
+ npfunc = getattr (xrutils , func )
93
+ else :
94
+ npfunc = getattr (np , func )
95
+
96
+ return npfunc
97
+
98
+
81
99
def test_alignment_error ():
82
100
da = np .ones ((12 ,))
83
101
labels = np .ones ((5 ,))
@@ -217,6 +235,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
217
235
if "arg" in func and add_nan_by :
218
236
array_ [..., nanmask ] = np .nan
219
237
expected = getattr (np , "nan" + func )(array_ , axis = - 1 , ** kwargs )
238
+ # elif func in ["first", "last"]:
239
+ # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
240
+ elif func in ["nanfirst" , "nanlast" ]:
241
+ expected = getattr (xrutils , func )(array_ [..., ~ nanmask ], axis = - 1 , ** kwargs )
220
242
else :
221
243
expected = getattr (np , func )(array_ [..., ~ nanmask ], axis = - 1 , ** kwargs )
222
244
for _ in range (nby ):
@@ -241,7 +263,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
241
263
call = partial (
242
264
groupby_reduce , array , * by , method = method , reindex = reindex , ** flox_kwargs
243
265
)
244
- if "arg" in func and reindex is True :
266
+ if ( "arg" in func or func in [ "first" , "last" ]) and reindex is True :
245
267
# simple_combine with argreductions not supported right now
246
268
with pytest .raises (NotImplementedError ):
247
269
call ()
@@ -486,6 +508,28 @@ def test_dask_reduce_axis_subset():
486
508
)
487
509
488
510
511
+ @pytest .mark .parametrize ("func" , ["first" , "last" , "nanfirst" , "nanlast" ])
512
+ @pytest .mark .parametrize ("axis" , [(0 , 1 )])
513
+ def test_first_last_disallowed (axis , func ):
514
+ with pytest .raises (ValueError ):
515
+ groupby_reduce (np .empty ((2 , 3 , 2 )), np .ones ((2 , 3 , 2 )), func = func , axis = axis )
516
+
517
+
518
+ @requires_dask
519
+ @pytest .mark .parametrize ("func" , ["nanfirst" , "nanlast" ])
520
+ @pytest .mark .parametrize ("axis" , [None , (0 , 1 , 2 )])
521
+ def test_nanfirst_nanlast_disallowed_dask (axis , func ):
522
+ with pytest .raises (ValueError ):
523
+ groupby_reduce (dask .array .empty ((2 , 3 , 2 )), np .ones ((2 , 3 , 2 )), func = func , axis = axis )
524
+
525
+
526
+ @requires_dask
527
+ @pytest .mark .parametrize ("func" , ["first" , "last" ])
528
+ def test_first_last_disallowed_dask (func ):
529
+ with pytest .raises (NotImplementedError ):
530
+ groupby_reduce (dask .array .empty ((2 , 3 , 2 )), np .ones ((2 , 3 , 2 )), func = func , axis = - 1 )
531
+
532
+
489
533
@requires_dask
490
534
@pytest .mark .parametrize ("func" , ALL_FUNCS )
491
535
@pytest .mark .parametrize (
@@ -495,8 +539,34 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
495
539
if "arg" in func and engine == "flox" :
496
540
pytest .skip ()
497
541
498
- if not isinstance (axis , int ) and "arg" in func and (axis is None or len (axis ) > 1 ):
499
- pytest .skip ()
542
+ if not isinstance (axis , int ):
543
+ if "arg" in func and (axis is None or len (axis ) > 1 ):
544
+ pytest .skip ()
545
+ if ("first" in func or "last" in func ) and (axis is not None and len (axis ) not in [1 , 3 ]):
546
+ pytest .skip ()
547
+
548
+ if func in ["all" , "any" ]:
549
+ fill_value = False
550
+ else :
551
+ fill_value = 123
552
+
553
+ if "var" in func or "std" in func :
554
+ tolerance = {"rtol" : 1e-14 , "atol" : 1e-16 }
555
+ else :
556
+ tolerance = None
557
+ # tests against the numpy output to make sure dask compute matches
558
+ by = np .broadcast_to (labels2d , (3 , * labels2d .shape ))
559
+ rng = np .random .default_rng (12345 )
560
+ array = rng .random (by .shape )
561
+ kwargs = dict (
562
+ func = func , axis = axis , expected_groups = [0 , 2 ], fill_value = fill_value , engine = engine
563
+ )
564
+ expected , _ = groupby_reduce (array , by , ** kwargs )
565
+ if engine == "flox" :
566
+ kwargs .pop ("engine" )
567
+ expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
568
+ assert_equal (expected_npg , expected )
569
+
500
570
if func in ["all" , "any" ]:
501
571
fill_value = False
502
572
else :
@@ -513,17 +583,23 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
513
583
kwargs = dict (
514
584
func = func , axis = axis , expected_groups = [0 , 2 ], fill_value = fill_value , engine = engine
515
585
)
586
+ expected , _ = groupby_reduce (array , by , ** kwargs )
587
+ if engine == "flox" :
588
+ kwargs .pop ("engine" )
589
+ expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
590
+ assert_equal (expected_npg , expected )
591
+
592
+ if ("first" in func or "last" in func ) and (
593
+ axis is None or (not isinstance (axis , int ) and len (axis ) != 1 )
594
+ ):
595
+ return
596
+
516
597
with raise_if_dask_computes ():
517
598
actual , _ = groupby_reduce (
518
599
da .from_array (array , chunks = (- 1 , 2 , 3 )),
519
600
da .from_array (by , chunks = (- 1 , 2 , 2 )),
520
601
** kwargs ,
521
602
)
522
- expected , _ = groupby_reduce (array , by , ** kwargs )
523
- if engine == "flox" :
524
- kwargs .pop ("engine" )
525
- expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
526
- assert_equal (expected_npg , expected )
527
603
assert_equal (actual , expected , tolerance )
528
604
529
605
@@ -751,23 +827,17 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
751
827
if chunks is not None and not has_dask :
752
828
pytest .skip ()
753
829
754
- if func == "count" :
755
-
756
- def npfunc (x ):
757
- x = np .asarray (x )
758
- return (~ np .isnan (x )).sum ()
759
-
760
- else :
761
- npfunc = getattr (np , func )
762
-
830
+ npfunc = _get_array_func (func )
763
831
by = np .array ([1 , 2 , 3 , 1 , 2 , 3 ])
764
832
array = np .array ([np .nan , 1 , 1 , np .nan , 1 , 1 ])
765
833
if chunks :
766
834
array = dask .array .from_array (array , chunks )
767
835
actual , _ = groupby_reduce (
768
836
array , by , func = func , engine = engine , fill_value = fill_value , expected_groups = [0 , 1 , 2 , 3 ]
769
837
)
770
- expected = np .array ([fill_value , fill_value , npfunc ([1.0 , 1.0 ]), npfunc ([1.0 , 1.0 ])])
838
+ expected = np .array (
839
+ [fill_value , fill_value , npfunc ([1.0 , 1.0 ], axis = 0 ), npfunc ([1.0 , 1.0 ], axis = 0 )]
840
+ )
771
841
assert_equal (actual , expected )
772
842
773
843
@@ -832,6 +902,8 @@ def test_cohorts_nd_by(func, method, axis, engine):
832
902
833
903
if axis is not None and method != "map-reduce" :
834
904
pytest .xfail ()
905
+ if axis is None and ("first" in func or "last" in func ):
906
+ pytest .skip ()
835
907
836
908
kwargs = dict (func = func , engine = engine , method = method , axis = axis , fill_value = fill_value )
837
909
actual , groups = groupby_reduce (array , by , ** kwargs )
@@ -897,7 +969,8 @@ def test_bool_reductions(func, engine):
897
969
pytest .skip ()
898
970
groups = np .array ([1 , 1 , 1 ])
899
971
data = np .array ([True , True , False ])
900
- expected = np .expand_dims (getattr (np , func )(data ), - 1 )
972
+ npfunc = _get_array_func (func )
973
+ expected = np .expand_dims (npfunc (data , axis = 0 ), - 1 )
901
974
actual , _ = groupby_reduce (data , groups , func = func , engine = engine )
902
975
assert_equal (expected , actual )
903
976
0 commit comments