Skip to content

Commit 8dac463

Browse files
authored
Parallelize ravel-multi-index (#433)
* Refactor out factorize loop * threadpool * Split out ravel_multi_index bits * Dask-ify ravel multi index * cleanup * Types
1 parent f8cfb5d commit 8dac463

File tree

2 files changed

+110
-90
lines changed

2 files changed

+110
-90
lines changed

flox/core.py

Lines changed: 103 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ class FactorizeKwargs(TypedDict, total=False):
195195
by: T_Bys
196196
axes: T_Axes
197197
fastpath: bool
198-
expected_groups: T_ExpectIndexOptTuple | None
199198
reindex: bool
200199
sort: bool
201200

@@ -844,6 +843,67 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
844843
return offset, size
845844

846845

846+
def _factorize_single(by, expect, *, sort: bool, reindex: bool):
847+
flat = by.reshape(-1)
848+
if isinstance(expect, pd.RangeIndex):
849+
# idx is a view of the original `by` array
850+
# copy here so we don't have a race condition with the
851+
# group_idx[nanmask] = nan_sentinel assignment later
852+
# this is important in shared-memory parallelism with dask
853+
# TODO: figure out how to avoid this
854+
idx = flat.copy()
855+
found_groups = np.array(expect)
856+
# TODO: fix by using masked integers
857+
idx[idx > expect[-1]] = -1
858+
859+
elif isinstance(expect, pd.IntervalIndex):
860+
if expect.closed == "both":
861+
raise NotImplementedError
862+
bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]])
863+
864+
# digitize is 0 or idx.max() for values outside the bounds of all intervals
865+
# make it behave like pd.cut which uses -1:
866+
if len(bins) > 1:
867+
right = expect.closed_right
868+
idx = np.digitize(
869+
flat,
870+
bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
871+
right=right,
872+
)
873+
idx -= 1
874+
within_bins = flat <= bins.max() if right else flat < bins.max()
875+
idx[~within_bins] = -1
876+
else:
877+
idx = np.zeros_like(flat, dtype=np.intp) - 1
878+
found_groups = np.array(expect)
879+
else:
880+
if expect is not None and reindex:
881+
sorter = np.argsort(expect)
882+
groups = expect[(sorter,)] if sort else expect
883+
idx = np.searchsorted(expect, flat, sorter=sorter)
884+
mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
885+
if not sort:
886+
# idx is the index in to the sorted array.
887+
# if we didn't want sorting, unsort it back
888+
idx[(idx == len(expect),)] = -1
889+
idx = sorter[(idx,)]
890+
idx[mask] = -1
891+
else:
892+
idx, groups = pd.factorize(flat, sort=sort)
893+
found_groups = np.array(groups)
894+
895+
return (found_groups, idx.reshape(by.shape))
896+
897+
898+
def _ravel_factorized(*factorized: np.ndarray, grp_shape: tuple[int, ...]) -> np.ndarray:
899+
group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
900+
# NaNs; as well as values outside the bins are coded by -1
901+
# Restore these after the raveling
902+
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
903+
group_idx[nan_by_mask] = -1
904+
return group_idx
905+
906+
847907
@overload
848908
def factorize_(
849909
by: T_Bys,
@@ -890,7 +950,7 @@ def factorize_(
890950
fastpath: bool = False,
891951
) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, FactorProps | None]:
892952
"""
893-
Returns an array of integer codes for groups (and associated data)
953+
Returns an array of integer codes for groups (and associated data)
894954
by wrapping pd.cut and pd.factorize (depending on isbin).
895955
This method handles reindex and sort so that we don't spend time reindexing / sorting
896956
a possibly large results array. Instead we set up the appropriate integer codes (group_idx)
@@ -899,75 +959,32 @@ def factorize_(
899959
if expected_groups is None:
900960
expected_groups = (None,) * len(by)
901961

902-
factorized = []
903-
found_groups = []
904-
for groupvar, expect in zip(by, expected_groups):
905-
flat = groupvar.reshape(-1)
906-
if isinstance(expect, pd.RangeIndex):
907-
# idx is a view of the original `by` array
908-
# copy here so we don't have a race condition with the
909-
# group_idx[nanmask] = nan_sentinel assignment later
910-
# this is important in shared-memory parallelism with dask
911-
# TODO: figure out how to avoid this
912-
idx = flat.copy()
913-
found_groups.append(np.array(expect))
914-
# TODO: fix by using masked integers
915-
idx[idx > expect[-1]] = -1
916-
917-
elif isinstance(expect, pd.IntervalIndex):
918-
if expect.closed == "both":
919-
raise NotImplementedError
920-
bins = np.concatenate([expect.left.to_numpy(), expect.right.to_numpy()[[-1]]])
921-
922-
# digitize is 0 or idx.max() for values outside the bounds of all intervals
923-
# make it behave like pd.cut which uses -1:
924-
if len(bins) > 1:
925-
right = expect.closed_right
926-
idx = np.digitize(
927-
flat,
928-
bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
929-
right=right,
930-
)
931-
idx -= 1
932-
within_bins = flat <= bins.max() if right else flat < bins.max()
933-
idx[~within_bins] = -1
934-
else:
935-
idx = np.zeros_like(flat, dtype=np.intp) - 1
936-
937-
found_groups.append(np.array(expect))
938-
else:
939-
if expect is not None and reindex:
940-
sorter = np.argsort(expect)
941-
groups = expect[(sorter,)] if sort else expect
942-
idx = np.searchsorted(expect, flat, sorter=sorter)
943-
mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
944-
if not sort:
945-
# idx is the index in to the sorted array.
946-
# if we didn't want sorting, unsort it back
947-
idx[(idx == len(expect),)] = -1
948-
idx = sorter[(idx,)]
949-
idx[mask] = -1
950-
else:
951-
idx, groups = pd.factorize(flat, sort=sort)
952-
953-
found_groups.append(np.array(groups))
954-
factorized.append(idx.reshape(groupvar.shape))
962+
if len(by) > 2:
963+
with ThreadPoolExecutor() as executor:
964+
futures = [
965+
executor.submit(partial(_factorize_single, sort=sort, reindex=reindex), groupvar, expect)
966+
for groupvar, expect in zip(by, expected_groups)
967+
]
968+
results = tuple(f.result() for f in futures)
969+
else:
970+
results = tuple(
971+
_factorize_single(groupvar, expect, sort=sort, reindex=reindex)
972+
for groupvar, expect in zip(by, expected_groups)
973+
)
974+
found_groups = [r[0] for r in results]
975+
factorized = [r[1] for r in results]
955976

956977
grp_shape = tuple(len(grp) for grp in found_groups)
957978
ngroups = math.prod(grp_shape)
958979
if len(by) > 1:
959-
group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap")
960-
# NaNs; as well as values outside the bins are coded by -1
961-
# Restore these after the raveling
962-
nan_by_mask = reduce(np.logical_or, [(f == -1) for f in factorized])
963-
group_idx[nan_by_mask] = -1
980+
group_idx = _ravel_factorized(*factorized, grp_shape=grp_shape)
964981
else:
965-
group_idx = factorized[0]
982+
(group_idx,) = factorized
966983

967984
if fastpath:
968985
return group_idx, tuple(found_groups), grp_shape, ngroups, ngroups, None
969986

970-
if len(axes) == 1 and groupvar.ndim > 1:
987+
if len(axes) == 1 and by[0].ndim > 1:
971988
# Not reducing along all dimensions of by
972989
# this is OK because for 3D by and axis=(1,2),
973990
# we collapse to a 2D by and axis=-1
@@ -2258,7 +2275,6 @@ def _factorize_multiple(
22582275
) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
22592276
kwargs: FactorizeKwargs = dict(
22602277
axes=(), # always (), we offset later if necessary.
2261-
expected_groups=expected_groups,
22622278
fastpath=True,
22632279
# This is the only way it makes sense I think.
22642280
# reindex controls what's actually allocated in chunk_reduce
@@ -2272,34 +2288,36 @@ def _factorize_multiple(
22722288
# unifying chunks will make sure all arrays in `by` are dask arrays
22732289
# with compatible chunks, even if there was originally a numpy array
22742290
inds = tuple(range(by[0].ndim))
2275-
chunks, by_ = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
2276-
2277-
group_idx = dask.array.map_blocks(
2278-
_lazy_factorize_wrapper,
2279-
*by_,
2280-
chunks=tuple(chunks.values()),
2281-
meta=np.array((), dtype=np.int64),
2282-
**kwargs,
2283-
)
2284-
2285-
fg, gs = [], []
22862291
for by_, expect in zip(by, expected_groups):
2287-
if expect is None:
2288-
if is_duck_dask_array(by_):
2289-
raise ValueError("Please provide expected_groups when grouping by a dask array.")
2292+
if expect is None and is_duck_dask_array(by_):
2293+
raise ValueError("Please provide expected_groups when grouping by a dask array.")
22902294

2291-
found_group = pd.unique(by_.reshape(-1))
2292-
else:
2293-
found_group = expect.to_numpy()
2295+
found_groups = tuple(
2296+
pd.unique(by_.reshape(-1)) if expect is None else expect.to_numpy()
2297+
for by_, expect in zip(by, expected_groups)
2298+
)
2299+
grp_shape = tuple(map(len, found_groups))
22942300

2295-
fg.append(found_group)
2296-
gs.append(len(found_group))
2301+
chunks, by_chunked = dask.array.unify_chunks(*itertools.chain(*zip(by, (inds,) * len(by))))
2302+
group_idxs = [
2303+
dask.array.map_blocks(
2304+
_lazy_factorize_wrapper,
2305+
by_,
2306+
expected_groups=(expect_,),
2307+
meta=np.array((), dtype=np.int64),
2308+
**kwargs,
2309+
)
2310+
for by_, expect_ in zip(by_chunked, expected_groups)
2311+
]
2312+
# This could be avoided, but then we'd use `np.where`
2313+
# instead of `_ravel_factorized`, i.e. an extra copy.
2314+
group_idx = dask.array.map_blocks(
2315+
_ravel_factorized, *group_idxs, grp_shape=grp_shape, chunks=tuple(chunks.values()), dtype=np.int64
2316+
)
22972317

2298-
found_groups = tuple(fg)
2299-
grp_shape = tuple(gs)
23002318
else:
23012319
kwargs["by"] = by
2302-
group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs)
2320+
group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs, expected_groups=expected_groups)
23032321

23042322
return (group_idx,), found_groups, grp_shape
23052323

flox/dask_array_ops.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@
44
from itertools import product
55
from numbers import Integral
66

7+
import dask
78
import pandas as pd
89
from dask import config
910
from dask.base import normalize_token
1011
from dask.blockwise import lol_tuples
12+
from packaging.version import Version
1113
from toolz import partition_all
1214

1315
from .lib import ArrayLayer
1416
from .types import Graph
1517

16-
17-
# workaround for https://github.com/dask/dask/issues/11862
18-
@normalize_token.register(pd.RangeIndex)
19-
def normalize_range_index(x):
20-
return normalize_token(type(x)), x.start, x.stop, x.step, x.dtype, x.name
18+
if Version(dask.__version__) <= Version("2025.03.1"):
    # Workaround for https://github.com/dask/dask/issues/11862:
    # older dask cannot deterministically tokenize a pandas RangeIndex,
    # so register a custom normalizer. Fixed upstream after 2025.03.1.
    @normalize_token.register(pd.RangeIndex)
    def normalize_range_index(x):
        """Return a stable token tuple for a pandas RangeIndex."""
        return (normalize_token(type(x)), x.start, x.stop, x.step, x.dtype, x.name)
2123

2224

2325
# _tree_reduce and partial_reduce are copied from dask.array.reductions

0 commit comments

Comments
 (0)