@@ -509,6 +509,7 @@ def chunk_argreduce(
     dask.array.reductions.argtopk
     """
     array, idx = array_plus_idx
+    by = np.broadcast_to(by, array.shape)

     results = chunk_reduce(
         array,
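The added `np.broadcast_to` call gives `by` one group label per element of `array` before both are flattened. A minimal standalone sketch of the effect (hypothetical shapes, not flox's actual data):

```python
import numpy as np

array = np.ones((2, 3))
by = np.array([0, 0, 1])  # group labels along the last axis only

# After broadcasting, `by` has one label per array element, so flattening
# array and by to 1D keeps values and labels aligned element-for-element.
by = np.broadcast_to(by, array.shape)
print(by.shape)  # (2, 3)
```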
@@ -522,17 +523,22 @@ def chunk_argreduce(
         sort=sort,
     )
     if not isnull(results["groups"]).all():
-        # will not work for empty groups...
-        # glorious
         idx = np.broadcast_to(idx, array.shape)
+
+        # array, by get flattened to 1D before passing to npg
+        # so the indexes need to be unraveled
         newidx = np.unravel_index(results["intermediates"][1], array.shape)
+
+        # Now index into the actual "global" indexes `idx`
         results["intermediates"][1] = idx[newidx]

     if reindex and expected_groups is not None:
         results["intermediates"][1] = reindex_(
             results["intermediates"][1], results["groups"].squeeze(), expected_groups, fill_value=0
         )

+    assert results["intermediates"][0].shape == results["intermediates"][1].shape
+
     return results

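The new comments describe the core trick: numpy_groupies receives `array` flattened to 1D, so the argreduction returns flat positions; `np.unravel_index` maps those back to nD indices, which then select the "global" indexes from `idx`. A minimal standalone sketch with made-up values (not flox's actual call):

```python
import numpy as np

array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 9.0]])
flat = np.argmax(array)  # 5: a position in array.ravel(), not a 2D index

# Unravel the flat position back into an index tuple for the 2D chunk...
nd = np.unravel_index(flat, array.shape)  # (1, 2)

# ...then look up the "global" index recorded for that element, e.g. its
# position in the full array rather than in this chunk.
idx = np.arange(100, 100 + array.size).reshape(array.shape)  # stand-in for `idx`
print(idx[nd])  # 105
```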
@@ -879,34 +885,45 @@ def _grouped_combine(
         array_idx = tuple(
             _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis) for idx in (0, 1)
         )
-        results = chunk_argreduce(
-            array_idx,
-            groups,
-            func=agg.combine[slicer],  # count gets treated specially next
-            axis=axis,
-            expected_groups=None,
-            fill_value=agg.fill_value["intermediate"][slicer],
-            dtype=agg.dtype["intermediate"][slicer],
-            engine=engine,
-            sort=sort,
-        )
+
+        # for a single element along axis, we don't want to run the argreduction twice
+        # This happens when we are reducing along an axis with a single chunk.
+        avoid_reduction = array_idx[0].shape[axis[0]] == 1
+        if avoid_reduction:
+            results = {"groups": groups, "intermediates": list(array_idx)}
+        else:
+            results = chunk_argreduce(
+                array_idx,
+                groups,
+                func=agg.combine[slicer],  # count gets treated specially next
+                axis=axis,
+                expected_groups=None,
+                fill_value=agg.fill_value["intermediate"][slicer],
+                dtype=agg.dtype["intermediate"][slicer],
+                engine=engine,
+                sort=sort,
+            )

         if agg.chunk[-1] == "nanlen":
             counts = _conc2(x_chunk, key1="intermediates", key2=2, axis=axis)
-            # sum the counts
-            results["intermediates"].append(
-                chunk_reduce(
-                    counts,
-                    groups,
-                    func="sum",
-                    axis=axis,
-                    expected_groups=None,
-                    fill_value=(0,),
-                    dtype=(np.intp,),
-                    engine=engine,
-                    sort=sort,
-                )["intermediates"][0]
-            )
+
+            if avoid_reduction:
+                results["intermediates"].append(counts)
+            else:
+                # sum the counts
+                results["intermediates"].append(
+                    chunk_reduce(
+                        counts,
+                        groups,
+                        func="sum",
+                        axis=axis,
+                        expected_groups=None,
+                        fill_value=(0,),
+                        dtype=(np.intp,),
+                        engine=engine,
+                        sort=sort,
+                    )["intermediates"][0]
+                )

     elif agg.reduction_type == "reduce":
         # Here we reduce the intermediates individually
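The new `avoid_reduction` branch skips the combine step when the concatenated intermediates have length one along the combine axis: with a single chunk there is nothing left to combine, and (per the new comment) re-running the argreduction would do the work twice. A toy illustration of the shape test, with made-up intermediates:

```python
import numpy as np

# Pretend intermediates from one chunk: one candidate value and one
# candidate index per group along the reduced (last) axis.
array_idx = (np.array([[3.0], [7.0]]), np.array([[2], [5]]))
axis = (-1,)

# Length 1 along the combine axis -> nothing to combine, pass through as-is.
avoid_reduction = array_idx[0].shape[axis[0]] == 1
print(avoid_reduction)  # True
```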
@@ -1006,24 +1023,7 @@ def _reduce_blockwise(array, by, agg, *, axis, expected_groups, fill_value, engi
     )  # type: ignore

     if _is_arg_reduction(agg):
-        if array.ndim > 1:
-            # default fill_value is -1; we can't unravel that;
-            # so replace -1 with 0; unravel; then replace 0 with -1
-            # UGH!
-            idx = results["intermediates"][0]
-            mask = idx == agg.fill_value["numpy"][0]
-            idx[mask] = 0
-            # Fix npg bug where argmax with nD array, 1D group_idx, axis=-1
-            # will return wrong indices
-            idx = np.unravel_index(idx, array.shape)[-1]
-            idx[mask] = agg.fill_value["numpy"][0]
-            results["intermediates"][0] = idx
-    elif agg.name in ["nanvar", "nanstd"]:
-        # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
-        value, counts = results["intermediates"]
-        mask = counts <= 0
-        value[mask] = np.nan
-        results["intermediates"][0] = value
+        results["intermediates"][0] = np.unravel_index(results["intermediates"][0], array.shape)[-1]

     result = _finalize_results(
         results, agg, axis, expected_groups, fill_value=fill_value, reindex=reindex
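The one-line replacement relies on `np.unravel_index` returning one coordinate array per axis; indexing the result with `[-1]` keeps only the coordinate along the last axis, which is the reduced axis in the blockwise case. For example, with made-up values:

```python
import numpy as np

flat_idx = np.array([2, 4])  # flat positions from the argreduction
shape = (2, 3)               # shape of the blockwise array

coords = np.unravel_index(flat_idx, shape)
print(coords)      # (array([0, 1]), array([2, 1])): one array per axis
print(coords[-1])  # array([2, 1]): positions along the last (reduced) axis
```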
@@ -1530,12 +1530,7 @@ def groupby_reduce(
     # The only way to do this consistently is mask out using min_count
     # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
-        if (
-            len(axis) < by.ndim
-            or fill_value is not None
-            # TODO: Fix npg bug where all-NaN rows are 0 instead of NaN
-            or (not has_dask and isinstance(func, str) and func in ["nanvar", "nanstd"])
-        ):
+        if len(axis) < by.ndim or fill_value is not None:
            min_count = 1

    # TODO: set in xarray?
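The retained comment explains the `min_count` fallback: `np.nansum` turns an all-NaN group into 0, while `np.sum` would propagate NaN, so masking groups with fewer than `min_count` valid values restores the NaN. A standalone sketch of that behaviour (not flox's actual masking code):

```python
import numpy as np

all_nan = np.array([np.nan])
print(np.sum(all_nan))     # nan
print(np.nansum(all_nan))  # 0.0

# Masking with min_count=1 restores NaN for groups with no valid values,
# so nansum-style aggregations match np.sum on all-NaN input.
valid = np.sum(~np.isnan(all_nan))  # 0 valid elements
result = np.nansum(all_nan)
if valid < 1:
    result = np.nan
print(result)  # nan
```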