@@ -195,7 +195,6 @@ class FactorizeKwargs(TypedDict, total=False):
195
195
by : T_Bys
196
196
axes : T_Axes
197
197
fastpath : bool
198
- expected_groups : T_ExpectIndexOptTuple | None
199
198
reindex : bool
200
199
sort : bool
201
200
@@ -844,6 +843,67 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
844
843
return offset , size
845
844
846
845
846
+ def _factorize_single (by , expect , * , sort : bool , reindex : bool ):
847
+ flat = by .reshape (- 1 )
848
+ if isinstance (expect , pd .RangeIndex ):
849
+ # idx is a view of the original `by` array
850
+ # copy here so we don't have a race condition with the
851
+ # group_idx[nanmask] = nan_sentinel assignment later
852
+ # this is important in shared-memory parallelism with dask
853
+ # TODO: figure out how to avoid this
854
+ idx = flat .copy ()
855
+ found_groups = np .array (expect )
856
+ # TODO: fix by using masked integers
857
+ idx [idx > expect [- 1 ]] = - 1
858
+
859
+ elif isinstance (expect , pd .IntervalIndex ):
860
+ if expect .closed == "both" :
861
+ raise NotImplementedError
862
+ bins = np .concatenate ([expect .left .to_numpy (), expect .right .to_numpy ()[[- 1 ]]])
863
+
864
+ # digitize is 0 or idx.max() for values outside the bounds of all intervals
865
+ # make it behave like pd.cut which uses -1:
866
+ if len (bins ) > 1 :
867
+ right = expect .closed_right
868
+ idx = np .digitize (
869
+ flat ,
870
+ bins = bins .view (np .int64 ) if bins .dtype .kind == "M" else bins ,
871
+ right = right ,
872
+ )
873
+ idx -= 1
874
+ within_bins = flat <= bins .max () if right else flat < bins .max ()
875
+ idx [~ within_bins ] = - 1
876
+ else :
877
+ idx = np .zeros_like (flat , dtype = np .intp ) - 1
878
+ found_groups = np .array (expect )
879
+ else :
880
+ if expect is not None and reindex :
881
+ sorter = np .argsort (expect )
882
+ groups = expect [(sorter ,)] if sort else expect
883
+ idx = np .searchsorted (expect , flat , sorter = sorter )
884
+ mask = ~ np .isin (flat , expect ) | isnull (flat ) | (idx == len (expect ))
885
+ if not sort :
886
+ # idx is the index in to the sorted array.
887
+ # if we didn't want sorting, unsort it back
888
+ idx [(idx == len (expect ),)] = - 1
889
+ idx = sorter [(idx ,)]
890
+ idx [mask ] = - 1
891
+ else :
892
+ idx , groups = pd .factorize (flat , sort = sort )
893
+ found_groups = np .array (groups )
894
+
895
+ return (found_groups , idx .reshape (by .shape ))
896
+
897
+
898
+ def _ravel_factorized (* factorized : np .ndarray , grp_shape : tuple [int , ...]) -> np .ndarray :
899
+ group_idx = np .ravel_multi_index (factorized , grp_shape , mode = "wrap" )
900
+ # NaNs; as well as values outside the bins are coded by -1
901
+ # Restore these after the raveling
902
+ nan_by_mask = reduce (np .logical_or , [(f == - 1 ) for f in factorized ])
903
+ group_idx [nan_by_mask ] = - 1
904
+ return group_idx
905
+
906
+
847
907
@overload
848
908
def factorize_ (
849
909
by : T_Bys ,
@@ -890,7 +950,7 @@ def factorize_(
890
950
fastpath : bool = False ,
891
951
) -> tuple [np .ndarray , tuple [np .ndarray , ...], tuple [int , ...], int , int , FactorProps | None ]:
892
952
"""
893
- Returns an array of integer codes for groups (and associated data)
953
+ Returns an array of integer codes for groups (and associated data)
894
954
by wrapping pd.cut and pd.factorize (depending on isbin).
895
955
This method handles reindex and sort so that we don't spend time reindexing / sorting
896
956
a possibly large results array. Instead we set up the appropriate integer codes (group_idx)
@@ -899,75 +959,32 @@ def factorize_(
899
959
if expected_groups is None :
900
960
expected_groups = (None ,) * len (by )
901
961
902
- factorized = []
903
- found_groups = []
904
- for groupvar , expect in zip (by , expected_groups ):
905
- flat = groupvar .reshape (- 1 )
906
- if isinstance (expect , pd .RangeIndex ):
907
- # idx is a view of the original `by` array
908
- # copy here so we don't have a race condition with the
909
- # group_idx[nanmask] = nan_sentinel assignment later
910
- # this is important in shared-memory parallelism with dask
911
- # TODO: figure out how to avoid this
912
- idx = flat .copy ()
913
- found_groups .append (np .array (expect ))
914
- # TODO: fix by using masked integers
915
- idx [idx > expect [- 1 ]] = - 1
916
-
917
- elif isinstance (expect , pd .IntervalIndex ):
918
- if expect .closed == "both" :
919
- raise NotImplementedError
920
- bins = np .concatenate ([expect .left .to_numpy (), expect .right .to_numpy ()[[- 1 ]]])
921
-
922
- # digitize is 0 or idx.max() for values outside the bounds of all intervals
923
- # make it behave like pd.cut which uses -1:
924
- if len (bins ) > 1 :
925
- right = expect .closed_right
926
- idx = np .digitize (
927
- flat ,
928
- bins = bins .view (np .int64 ) if bins .dtype .kind == "M" else bins ,
929
- right = right ,
930
- )
931
- idx -= 1
932
- within_bins = flat <= bins .max () if right else flat < bins .max ()
933
- idx [~ within_bins ] = - 1
934
- else :
935
- idx = np .zeros_like (flat , dtype = np .intp ) - 1
936
-
937
- found_groups .append (np .array (expect ))
938
- else :
939
- if expect is not None and reindex :
940
- sorter = np .argsort (expect )
941
- groups = expect [(sorter ,)] if sort else expect
942
- idx = np .searchsorted (expect , flat , sorter = sorter )
943
- mask = ~ np .isin (flat , expect ) | isnull (flat ) | (idx == len (expect ))
944
- if not sort :
945
- # idx is the index in to the sorted array.
946
- # if we didn't want sorting, unsort it back
947
- idx [(idx == len (expect ),)] = - 1
948
- idx = sorter [(idx ,)]
949
- idx [mask ] = - 1
950
- else :
951
- idx , groups = pd .factorize (flat , sort = sort )
952
-
953
- found_groups .append (np .array (groups ))
954
- factorized .append (idx .reshape (groupvar .shape ))
962
+ if len (by ) > 2 :
963
+ with ThreadPoolExecutor () as executor :
964
+ futures = [
965
+ executor .submit (partial (_factorize_single , sort = sort , reindex = reindex ), groupvar , expect )
966
+ for groupvar , expect in zip (by , expected_groups )
967
+ ]
968
+ results = tuple (f .result () for f in futures )
969
+ else :
970
+ results = tuple (
971
+ _factorize_single (groupvar , expect , sort = sort , reindex = reindex )
972
+ for groupvar , expect in zip (by , expected_groups )
973
+ )
974
+ found_groups = [r [0 ] for r in results ]
975
+ factorized = [r [1 ] for r in results ]
955
976
956
977
grp_shape = tuple (len (grp ) for grp in found_groups )
957
978
ngroups = math .prod (grp_shape )
958
979
if len (by ) > 1 :
959
- group_idx = np .ravel_multi_index (factorized , grp_shape , mode = "wrap" )
960
- # NaNs; as well as values outside the bins are coded by -1
961
- # Restore these after the raveling
962
- nan_by_mask = reduce (np .logical_or , [(f == - 1 ) for f in factorized ])
963
- group_idx [nan_by_mask ] = - 1
980
+ group_idx = _ravel_factorized (* factorized , grp_shape = grp_shape )
964
981
else :
965
- group_idx = factorized [ 0 ]
982
+ ( group_idx ,) = factorized
966
983
967
984
if fastpath :
968
985
return group_idx , tuple (found_groups ), grp_shape , ngroups , ngroups , None
969
986
970
- if len (axes ) == 1 and groupvar .ndim > 1 :
987
+ if len (axes ) == 1 and by [ 0 ] .ndim > 1 :
971
988
# Not reducing along all dimensions of by
972
989
# this is OK because for 3D by and axis=(1,2),
973
990
# we collapse to a 2D by and axis=-1
@@ -2258,7 +2275,6 @@ def _factorize_multiple(
2258
2275
) -> tuple [tuple [np .ndarray ], tuple [np .ndarray , ...], tuple [int , ...]]:
2259
2276
kwargs : FactorizeKwargs = dict (
2260
2277
axes = (), # always (), we offset later if necessary.
2261
- expected_groups = expected_groups ,
2262
2278
fastpath = True ,
2263
2279
# This is the only way it makes sense I think.
2264
2280
# reindex controls what's actually allocated in chunk_reduce
@@ -2272,34 +2288,36 @@ def _factorize_multiple(
2272
2288
# unifying chunks will make sure all arrays in `by` are dask arrays
2273
2289
# with compatible chunks, even if there was originally a numpy array
2274
2290
inds = tuple (range (by [0 ].ndim ))
2275
- chunks , by_ = dask .array .unify_chunks (* itertools .chain (* zip (by , (inds ,) * len (by ))))
2276
-
2277
- group_idx = dask .array .map_blocks (
2278
- _lazy_factorize_wrapper ,
2279
- * by_ ,
2280
- chunks = tuple (chunks .values ()),
2281
- meta = np .array ((), dtype = np .int64 ),
2282
- ** kwargs ,
2283
- )
2284
-
2285
- fg , gs = [], []
2286
2291
for by_ , expect in zip (by , expected_groups ):
2287
- if expect is None :
2288
- if is_duck_dask_array (by_ ):
2289
- raise ValueError ("Please provide expected_groups when grouping by a dask array." )
2292
+ if expect is None and is_duck_dask_array (by_ ):
2293
+ raise ValueError ("Please provide expected_groups when grouping by a dask array." )
2290
2294
2291
- found_group = pd .unique (by_ .reshape (- 1 ))
2292
- else :
2293
- found_group = expect .to_numpy ()
2295
+ found_groups = tuple (
2296
+ pd .unique (by_ .reshape (- 1 )) if expect is None else expect .to_numpy ()
2297
+ for by_ , expect in zip (by , expected_groups )
2298
+ )
2299
+ grp_shape = tuple (map (len , found_groups ))
2294
2300
2295
- fg .append (found_group )
2296
- gs .append (len (found_group ))
2301
+ chunks , by_chunked = dask .array .unify_chunks (* itertools .chain (* zip (by , (inds ,) * len (by ))))
2302
+ group_idxs = [
2303
+ dask .array .map_blocks (
2304
+ _lazy_factorize_wrapper ,
2305
+ by_ ,
2306
+ expected_groups = (expect_ ,),
2307
+ meta = np .array ((), dtype = np .int64 ),
2308
+ ** kwargs ,
2309
+ )
2310
+ for by_ , expect_ in zip (by_chunked , expected_groups )
2311
+ ]
2312
+ # This could be avoied but we'd use `np.where`
2313
+ # instead `_ravel_factorized` instead i.e. a copy.
2314
+ group_idx = dask .array .map_blocks (
2315
+ _ravel_factorized , * group_idxs , grp_shape = grp_shape , chunks = tuple (chunks .values ()), dtype = np .int64
2316
+ )
2297
2317
2298
- found_groups = tuple (fg )
2299
- grp_shape = tuple (gs )
2300
2318
else :
2301
2319
kwargs ["by" ] = by
2302
- group_idx , found_groups , grp_shape , * _ = factorize_ (** kwargs )
2320
+ group_idx , found_groups , grp_shape , * _ = factorize_ (** kwargs , expected_groups = expected_groups )
2303
2321
2304
2322
return (group_idx ,), found_groups , grp_shape
2305
2323
0 commit comments