
Commit 11d2b7d

Cleanups (#321)
* Cleanup
* Fix types
1 parent d2cc4a1 commit 11d2b7d

File tree: 1 file changed (+42, -38 lines)


flox/core.py (+42, -38)
@@ -15,6 +15,7 @@
     Any,
     Callable,
     Literal,
+    TypedDict,
     Union,
     overload,
 )
@@ -87,6 +88,17 @@
 DUMMY_AXIS = -2


+class FactorizeKwargs(TypedDict, total=False):
+    """Used in _factorize_multiple"""
+
+    by: T_Bys
+    axes: T_Axes
+    fastpath: bool
+    expected_groups: T_ExpectIndexOptTuple | None
+    reindex: bool
+    sort: bool
+
+
 def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     """Account for numbagg not providing a fill_value kwarg."""
     from .aggregate_numbagg import DEFAULT_FILL_VALUE
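The new `FactorizeKwargs` class exists so the kwargs dict built in `_factorize_multiple` (later in this diff) can be type-checked; `total=False` makes every key optional, which lets the dict be filled incrementally. A minimal sketch of the pattern outside flox, with hypothetical names:

```python
from typing import TypedDict


class QueryKwargs(TypedDict, total=False):
    # total=False: every key is optional, so a partially-built
    # dict still satisfies the annotation.
    table: str
    limit: int
    sort: bool


def run_query(table: str, limit: int = 10, sort: bool = True) -> str:
    return f"{table=} {limit=} {sort=}"


kwargs: QueryKwargs = dict(limit=5, sort=False)
kwargs["table"] = "users"  # add the last key later, as the diff does with "by"
print(run_query(**kwargs))  # the TypedDict documents exactly what ** unpacks to
```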
@@ -1434,7 +1446,7 @@ def dask_groupby_agg(
     _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

     # tokenize here since by has already been hashed if its numpy
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis)
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis, method)

     # preprocess the array:
     # - for argreductions, this zips the index together with the array block
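Adding `method` to the token matters because `dask.base.tokenize` is a deterministic hash of its arguments, and the token is used to build unique task names: two graphs that differ only in `method` must not share intermediate keys. A small illustration (argument values are made up, not from flox):

```python
import dask.base

# Same inputs always hash to the same token...
assert dask.base.tokenize([1, 2, 3], "sum", 0) == dask.base.tokenize([1, 2, 3], "sum", 0)

# ...while adding or changing any argument, such as `method`, changes it.
t_mapreduce = dask.base.tokenize([1, 2, 3], "sum", 0, "map-reduce")
t_cohorts = dask.base.tokenize([1, 2, 3], "sum", 0, "cohorts")
assert t_mapreduce != t_cohorts
```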
@@ -1454,7 +1466,8 @@ def dask_groupby_agg(
     # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #    This allows us to discover groups at compute time, support argreductions, lower intermediate
     #    memory usage (but method="cohorts" would also work to reduce memory in some cases)
-    do_simple_combine = not _is_arg_reduction(agg)
+    labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
+    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown

     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
@@ -1510,7 +1523,7 @@ def dask_groupby_agg(

     tree_reduce = partial(
         dask.array.reductions._tree_reduce,
-        name=f"{name}-reduce-{method}",
+        name=f"{name}-reduce",
         dtype=array.dtype,
         axis=axis,
         keepdims=True,
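Dropping the `-{method}` suffix from the tree-reduce layer name pairs with the tokenize change above: since `method` now feeds into the token used to derive task names, graphs built with different methods already get distinct names, so repeating `method` in the layer name is presumably redundant.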
@@ -1529,7 +1542,7 @@
         combine=partial(combine, agg=agg),
         aggregate=partial(aggregate, expected_groups=expected_groups, reindex=reindex),
     )
-    if is_duck_dask_array(by_input) and expected_groups is None:
+    if labels_are_unknown:
         groups = _extract_unknown_groups(reduced, dtype=by.dtype)
         group_chunks = ((np.nan,),)
     else:
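For the unknown-labels branch kept here, `group_chunks = ((np.nan,),)` uses dask's convention that a NaN chunk size means "length not known until compute". The same convention shows up with boolean indexing; a quick sketch:

```python
import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=5)
y = x[x > 2]  # output length depends on the data, so it is unknown
print(y.chunks)  # ((nan, nan),): NaN marks chunk sizes resolved only at compute time
```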
@@ -1747,17 +1760,26 @@ def _convert_expected_groups_to_index(


 def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
-    group_idx, *rest = factorize_(by, **kwargs)
+    group_idx, *_ = factorize_(by, **kwargs)
     return group_idx


 def _factorize_multiple(
     by: T_Bys,
     expected_groups: T_ExpectIndexOptTuple,
     any_by_dask: bool,
-    reindex: bool,
     sort: bool = True,
 ) -> tuple[tuple[np.ndarray], tuple[np.ndarray, ...], tuple[int, ...]]:
+    kwargs: FactorizeKwargs = dict(
+        axes=(),  # always (), we offset later if necessary.
+        expected_groups=expected_groups,
+        fastpath=True,
+        # This is the only way it makes sense I think.
+        # reindex controls what's actually allocated in chunk_reduce
+        # At this point, we care about an accurate conversion to codes.
+        reindex=True,
+        sort=sort,
+    )
     if any_by_dask:
         import dask.array

@@ -1771,11 +1793,7 @@ def _factorize_multiple(
             *by_,
             chunks=tuple(chunks.values()),
             meta=np.array((), dtype=np.int64),
-            axes=(),  # always (), we offset later if necessary.
-            expected_groups=expected_groups,
-            fastpath=True,
-            reindex=reindex,
-            sort=sort,
+            **kwargs,
         )

         fg, gs = [], []
@@ -1796,14 +1814,8 @@
         found_groups = tuple(fg)
         grp_shape = tuple(gs)
     else:
-        group_idx, found_groups, grp_shape, ngroups, size, props = factorize_(
-            by,
-            axes=(),  # always (), we offset later if necessary.
-            expected_groups=expected_groups,
-            fastpath=True,
-            reindex=reindex,
-            sort=sort,
-        )
+        kwargs["by"] = by
+        group_idx, found_groups, grp_shape, *_ = factorize_(**kwargs)

     return (group_idx,), found_groups, grp_shape

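Both `factorize_` call sites now discard the trailing return values with `*_` rather than binding them to unused names (`ngroups`, `size`, `props` in the old code). A toy stand-in for `factorize_` (hypothetical values) showing the unpacking:

```python
def fake_factorize():
    # stand-in for factorize_, which returns more values than most callers need
    return "group_idx", ("found",), (2,), 2, 4, None


group_idx, found_groups, grp_shape, *_ = fake_factorize()
# *_ collects the remaining items (ngroups, size, props) into a throwaway list
assert group_idx == "group_idx" and grp_shape == (2,)
```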
@@ -2058,7 +2070,7 @@ def groupby_reduce(
     # (pd.IntervalIndex or not)
     expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort)

-    # Don't factorize "early only when
+    # Don't factorize early only when
     # grouping by dask arrays, and not having expected_groups
     factorize_early = not (
         # can't do it if we are grouping by dask array but don't have expected_groups
@@ -2069,10 +2081,6 @@ def groupby_reduce(
             bys,
             expected_groups,
             any_by_dask=any_by_dask,
-            # This is the only way it makes sense I think.
-            # reindex controls what's actually allocated in chunk_reduce
-            # At this point, we care about an accurate conversion to codes.
-            reindex=True,
             sort=sort,
         )
         expected_groups = (pd.RangeIndex(math.prod(grp_shape)),)
@@ -2103,21 +2111,17 @@ def groupby_reduce(
             "along a single axis or when reducing across all dimensions of `by`."
         )

-    # TODO: make sure expected_groups is unique
     if nax == 1 and by_.ndim > 1 and expected_groups is None:
-        if not any_by_dask:
-            expected_groups = _get_expected_groups(by_, sort)
-        else:
-            # When we reduce along all axes, we are guaranteed to see all
-            # groups in the final combine stage, so everything works.
-            # This is not necessarily true when reducing along a subset of axes
-            # (of by)
-            # TODO: Does this depend on chunking of by?
-            # For e.g., we could relax this if there is only one chunk along all
-            # by dim != axis?
-            raise NotImplementedError(
-                "Please provide ``expected_groups`` when not reducing along all axes."
-            )
+        # When we reduce along all axes, we are guaranteed to see all
+        # groups in the final combine stage, so everything works.
+        # This is not necessarily true when reducing along a subset of axes
+        # (of by)
+        # TODO: Does this depend on chunking of by?
+        # For e.g., we could relax this if there is only one chunk along all
+        # by dim != axis?
+        raise NotImplementedError(
+            "Please provide ``expected_groups`` when not reducing along all axes."
+        )

     assert nax <= by_.ndim
     if nax < by_.ndim:
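The user-visible change in this last hunk: reducing along a subset of a multi-dimensional `by`'s axes without `expected_groups` now raises for numpy inputs too, where it previously computed the groups eagerly via `_get_expected_groups`. A hedged usage sketch (assuming flox's public `groupby_reduce` signature; the error text is taken from the diff):

```python
import numpy as np
from flox.core import groupby_reduce

labels = np.array([[0, 1], [1, 0]])  # 2-D `by`
values = np.ones((2, 2))

# Reducing over all axes of `by` still discovers the groups automatically.
result, groups = groupby_reduce(values, labels, func="sum")

# Reducing along only one axis now requires expected_groups, even for numpy:
try:
    groupby_reduce(values, labels, func="sum", axis=-1)
except NotImplementedError as err:
    print(err)  # Please provide ``expected_groups`` when not reducing along all axes.
```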
