
Commit e2aa2be

Single layer for cohorts. (#415)
Does so by vendoring in the tree reduction code and modifying it to work purely at the Graph level, with no array creation. For kicks, I fused in the concatenation step. Benchmarks are looking great! I don't understand the `ERA5MonthHourRechunked` regression, but it's quite minor.

```
| Before [df81] | After [761367d] | Ratio | Benchmark (Parameter)                                     |
|---------------|-----------------|-------|-----------------------------------------------------------|
| 3672          | 4266            | 1.16  | cohorts.ERA5MonthHourRechunked.track_num_tasks_optimized  |
| 3834          | 3469            | 0.9   | cohorts.ERA5DayOfYear.track_num_tasks                     |
| 999±20ms      | 822±60ms        | 0.82  | cohorts.ERA5Resampling.time_graph_construct               |
| 11            | 6               | 0.55  | cohorts.ERA5MonthHourRechunked.track_num_layers           |
| 10            | 5               | 0.5   | cohorts.ERA5MonthHour.track_num_layers                    |
| 17            | 5               | 0.29  | cohorts.PerfectMonthly.track_num_layers                   |
| 128           | 5               | 0.04  | cohorts.ERA5Google.track_num_layers                       |
| 266           | 6               | 0.02  | cohorts.NWMMidwest.track_num_layers                       |
| 735           | 5               | 0.01  | cohorts.ERA5DayOfYear.track_num_layers                    |
| 490           | 6               | 0.01  | cohorts.OISST.track_num_layers                            |
| 7305          | 5               | 0     | cohorts.ERA5Resampling.track_num_layers                   |
```
1 parent cb5d203 commit e2aa2be
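A minimal sketch of the "single layer" idea described in the commit message, using hypothetical names and toy tasks rather than the commit's actual helpers: instead of building one dask array per cohort and concatenating them at the end, every cohort writes its final reduction key straight into one shared graph at its own block index along the reduced axis, so the trailing concatenate layer disappears.

```python
# Hypothetical illustration only (toy names and toy tasks, not flox's helpers):
# each cohort writes its final reduction key straight into one shared graph,
# keyed by its position along the reduced axis, so no per-cohort Array objects
# and no trailing `concatenate` layer are needed.
out_name = "cohorts-reduce-sketch"
dsk = {}  # the single shared layer

chunks_cohorts = {(0, 2): ["jan", "feb"], (1, 3): ["mar", "apr"]}  # blocks -> groups
for icohort, (blocks, cohort) in enumerate(chunks_cohorts.items()):
    # `sum` stands in for the vendored tree reduction over this cohort's blocks
    dsk[(out_name, 0, icohort)] = (sum, list(blocks))

print(dsk)  # one layer, one final key per cohort, nothing left to concatenate
```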

File tree

9 files changed, +198 -45 lines changed

.github/workflows/benchmarks.yml

+3-1
```diff
@@ -26,9 +26,11 @@ jobs:
       with:
         environment-name: flox-bench
         create-args: >-
-          python=3.10
+          python=3.12
           asv
           mamba
+          libmambapy<2.0
+          conda-build
         init-shell: bash
         cache-environment: true
```

asv_bench/benchmarks/cohorts.py

+16
```diff
@@ -113,6 +113,22 @@ def rechunk(self):
         )


+class ERA5Resampling(Cohorts):
+    def setup(self, *args, **kwargs):
+        super().__init__()
+        # nyears is number of years, adjust to make bigger,
+        # full dataset is 60-ish years.
+        nyears = 5
+        shape = (37, 721, 1440, nyears * 365 * 24)
+        chunks = (-1, -1, -1, 1)
+        time = pd.date_range("2001-01-01", periods=shape[-1], freq="h")
+
+        self.array = dask.array.random.random(shape, chunks=chunks)
+        self.by = codes_for_resampling(time, "D")
+        self.axis = (-1,)
+        self.expected = np.unique(self.by)
+
+
 class ERA5DayOfYear(ERA5Dataset, Cohorts):
     def setup(self, *args, **kwargs):
         super().__init__()
```

ci/benchmark.yml

+1-2
```diff
@@ -6,10 +6,9 @@ dependencies:
   - build
   - cachey
   - dask-core
-  - numpy<2
+  - numpy<2.1
   - mamba
   - pip
-  - python=3.10
   - xarray
   - numpy_groupies>=0.9.19
   - numbagg>=0.3
```

flox/core.py

+31-29
```diff
@@ -44,6 +44,7 @@
     quantile_new_dims_func,
 )
 from .cache import memoize
+from .lib import ArrayLayer
 from .xrutils import (
     _contains_cftime_datetimes,
     _to_pytimedelta,
@@ -72,10 +73,7 @@
         from typing import Unpack
     except (ModuleNotFoundError, ImportError):
         Unpack: Any  # type: ignore[no-redef]
-
-    import cubed.Array as CubedArray
-    import dask.array.Array as DaskArray
-    from dask.typing import Graph
+    from .types import CubedArray, DaskArray, Graph

     T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray  # Any ?
     T_By: TypeAlias = T_DuckArray
@@ -1191,7 +1189,7 @@ def _aggregate(
     agg: Aggregation,
     expected_groups: pd.Index | None,
     axis: T_Axes,
-    keepdims,
+    keepdims: bool,
     fill_value: Any,
     reindex: bool,
 ) -> FinalResultsDict:
@@ -1511,7 +1509,7 @@ def subset_to_blocks(
     blkshape: tuple[int, ...] | None = None,
     reindexer=identity,
     chunks_as_array: tuple[np.ndarray, ...] | None = None,
-) -> DaskArray:
+) -> ArrayLayer:
     """
     Advanced indexing of .blocks such that we always get a regular array back.

@@ -1525,10 +1523,8 @@ def subset_to_blocks(
     -------
     dask.array
     """
-    import dask.array
     from dask.array.slicing import normalize_index
     from dask.base import tokenize
-    from dask.highlevelgraph import HighLevelGraph

     if blkshape is None:
         blkshape = array.blocks.shape
@@ -1538,9 +1534,6 @@ def subset_to_blocks(

     index = _normalize_indexes(array, flatblocks, blkshape)

-    if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return dask.array.map_blocks(reindexer, array, meta=array._meta)
-
     # These rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
@@ -1553,10 +1546,7 @@ def subset_to_blocks(

     keys = itertools.product(*(range(len(c)) for c in chunks))
     layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
-
-    graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
-
-    return dask.array.Array(graph, name, chunks, meta=array)
+    return ArrayLayer(layer=layer, chunks=chunks, name=name)


 def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
@@ -1613,6 +1603,9 @@ def dask_groupby_agg(
 ) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
+    from dask.highlevelgraph import HighLevelGraph
+
+    from .dask_array_ops import _tree_reduce

     # I think _tree_reduce expects this
     assert isinstance(axis, Sequence)
@@ -1742,35 +1735,44 @@ def dask_groupby_agg(
         assert chunks_cohorts
         block_shape = array.blocks.shape[-len(axis) :]

-        reduced_ = []
+        out_name = f"{name}-reduce-{method}-{token}"
         groups_ = []
         chunks_as_array = tuple(np.array(c) for c in array.chunks)
-        for blks, cohort in chunks_cohorts.items():
+        dsk: Graph = {}
+        for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()):
             cohort_index = pd.Index(cohort)
             reindexer = (
                 partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
                 if do_simple_combine
                 else identity
             )
-            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+            subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+            dsk |= subset.layer  # type: ignore[operator]
             # now that we have reindexed, we can set reindex=True explicitlly
-            reduced_.append(
-                tree_reduce(
-                    reindexed,
-                    combine=partial(combine, agg=agg, reindex=do_simple_combine),
-                    aggregate=partial(
-                        aggregate,
-                        expected_groups=cohort_index,
-                        reindex=do_simple_combine,
-                    ),
-                )
+            _tree_reduce(
+                subset,
+                out_dsk=dsk,
+                name=out_name,
+                block_index=icohort,
+                axis=axis,
+                combine=partial(combine, agg=agg, reindex=do_simple_combine, keepdims=True),
+                aggregate=partial(
+                    aggregate, expected_groups=cohort_index, reindex=do_simple_combine, keepdims=True
+                ),
             )
             # This is done because pandas promotes to 64-bit types when an Index is created
             # So we use the index to generate the return value for consistency with "map-reduce"
             # This is important on windows
             groups_.append(cohort_index.values)

-        reduced = dask.array.concatenate(reduced_, axis=-1)
+        graph = HighLevelGraph.from_collections(out_name, dsk, dependencies=[intermediate])
+
+        out_chunks = list(array.chunks)
+        out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values())
+        for ax in axis[:-1]:
+            out_chunks[ax] = (1,)
+        reduced = dask.array.Array(graph, out_name, out_chunks, meta=array._meta)

         groups = (np.concatenate(groups_),)
         group_chunks = (tuple(len(cohort) for cohort in groups_),)
```

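For concreteness, the new output-chunk bookkeeping in the last hunk can be traced with a toy cohort mapping (hypothetical numbers, not taken from the benchmarks): each cohort contributes one output chunk along the reduced axis, sized by the number of groups it holds, and any other reduced axes collapse to a single chunk.

```python
# toy illustration of the out_chunks computation in the hunk above
chunks = ((10, 10), (4, 4, 4))                   # input chunks; reduce over axis -1
chunks_cohorts = {(0, 2): [1, 2, 3], (1,): [4]}  # blocks -> group labels per cohort
axis = (-1,)

out_chunks = list(chunks)
out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values())
for ax in axis[:-1]:
    out_chunks[ax] = (1,)

print(out_chunks)  # [(10, 10), (3, 1)]: one output chunk per cohort, sized by its group count
```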
flox/dask_array_ops.py

+103
```diff
@@ -0,0 +1,103 @@
+import builtins
+import math
+from functools import partial
+from itertools import product
+from numbers import Integral
+
+from dask import config
+from dask.blockwise import lol_tuples
+from toolz import partition_all
+
+from .lib import ArrayLayer
+from .types import Graph
+
+
+# _tree_reduce and partial_reduce are copied from dask.array.reductions
+# They have been modified to work purely with graphs, and without creating new Array layers
+# in the graph. The `block_index` kwarg is new and avoids a concatenation by simply setting the right
+# key initially
+def _tree_reduce(
+    x: ArrayLayer,
+    *,
+    name: str,
+    out_dsk: Graph,
+    aggregate,
+    axis: tuple[int, ...],
+    block_index: int,
+    split_every=None,
+    combine=None,
+):
+    # Normalize split_every
+    split_every = split_every or config.get("split_every", 4)
+    if isinstance(split_every, dict):
+        split_every = {k: split_every.get(k, 2) for k in axis}
+    elif isinstance(split_every, Integral):
+        n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
+        split_every = dict.fromkeys(axis, n)
+    else:
+        raise ValueError("split_every must be a int or a dict")
+
+    numblocks = tuple(len(c) for c in x.chunks)
+    out_chunks = x.chunks
+
+    # Reduce across intermediates
+    depth = 1
+    for i, n in enumerate(numblocks):
+        if i in split_every and split_every[i] != 1:
+            depth = int(builtins.max(depth, math.ceil(math.log(n, split_every[i]))))
+    func = partial(combine or aggregate, axis=axis)
+
+    agg_dep_name = x.name
+    for level in range(depth - 1):
+        newname = name + f"-{block_index}-partial-{level}"
+        out_dsk, out_chunks = partial_reduce(
+            func,
+            out_dsk,
+            chunks=out_chunks,
+            split_every=split_every,
+            name=newname,
+            dep_name=agg_dep_name,
+            axis=axis,
+        )
+        agg_dep_name = newname
+    func = partial(aggregate, axis=axis)
+    return partial_reduce(
+        func,
+        out_dsk,
+        chunks=out_chunks,
+        split_every=split_every,
+        name=name,
+        dep_name=agg_dep_name,
+        axis=axis,
+        block_index=block_index,
+    )
+
+
+def partial_reduce(
+    func,
+    dsk,
+    *,
+    chunks: tuple[tuple[int, ...], ...],
+    name: str,
+    dep_name: str,
+    split_every: dict[int, int],
+    axis: tuple[int, ...],
+    block_index: int | None = None,
+):
+    numblocks = tuple(len(c) for c in chunks)
+    ndim = len(numblocks)
+    parts = [list(partition_all(split_every.get(i, 1), range(n))) for (i, n) in enumerate(numblocks)]
+    keys = product(*map(range, map(len, parts)))
+    out_chunks = [
+        tuple(1 for p in partition_all(split_every[i], c)) if i in split_every else c
+        for (i, c) in enumerate(chunks)
+    ]
+    for k, p in zip(keys, product(*parts)):
+        free = {i: j[0] for (i, j) in enumerate(p) if len(j) == 1 and i not in split_every}
+        dummy = dict(i for i in enumerate(p) if i[0] in split_every)
+        g = lol_tuples((dep_name,), range(ndim), free, dummy)
+        assert dep_name != name
+        if block_index is not None:
+            k = (*k[:-1], block_index)
+        dsk[(name,) + k] = (func, g)
+    return dsk, out_chunks
```
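As a point of reference for the vendored `partial_reduce` above: `lol_tuples` (imported from `dask.blockwise`, as in the diff) is what expands the "dummy" reduced axes into the list of input keys each `combine`/`aggregate` call receives. A small standalone illustration, with a made-up layer name:

```python
from dask.blockwise import lol_tuples

# free axes keep a fixed block index; the dummy (reduced) axis expands into
# every block index being merged into this output key
keys = lol_tuples(("made-up-layer",), range(2), {0: 3}, {1: [0, 1, 2]})
print(keys)
# [('made-up-layer', 3, 0), ('made-up-layer', 3, 1), ('made-up-layer', 3, 2)]
```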

flox/lib.py

+17
```diff
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+from .types import DaskArray, Graph
+
+
+@dataclass
+class ArrayLayer:
+    name: str
+    layer: Graph
+    chunks: tuple[tuple[int, ...], ...]
+
+    def to_array(self, dep: DaskArray) -> DaskArray:
+        from dask.array import Array
+        from dask.highlevelgraph import HighLevelGraph
+
+        graph = HighLevelGraph.from_collections(self.name, self.layer, dependencies=[dep])
+        return Array(graph, self.name, self.chunks, meta=dep._meta)
```
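A short usage sketch of the new `ArrayLayer` (it mirrors the updated `test_subset_blocks` in the test changes below; the exact call pattern is an assumption drawn from this diff): `subset_to_blocks` now hands back a graph layer plus chunk metadata, and an actual dask Array is only materialized when the caller asks for one.

```python
import dask.array
from flox.core import subset_to_blocks

array = dask.array.random.random((120,), chunks=(4,))
# returns an ArrayLayer (name + layer + chunks); no dask Array is built yet
layer = subset_to_blocks(array, (0, 3, 6, 9))
# materialize only when an Array is actually needed, as the tests do
subset = layer.to_array(array)
assert subset.blocks.shape == (4,)
```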

flox/types.py

+13
```diff
@@ -0,0 +1,13 @@
+from typing import Any, TypeAlias
+
+try:
+    import cubed.Array as CubedArray
+except ImportError:
+    CubedArray = Any
+
+try:
+    import dask.array.Array as DaskArray
+    from dask.typing import Graph
+except ImportError:
+    DaskArray = Any
+    Graph: TypeAlias = Any  # type: ignore[no-redef,misc]
```

pyproject.toml

+1-1
```diff
@@ -133,7 +133,7 @@ module=[
     "pandas",
     "setuptools",
     "scipy.*",
-    "toolz",
+    "toolz.*",
 ]
 ignore_missing_imports = true
```

tests/test_core.py

+13-12
```diff
@@ -1524,15 +1524,6 @@ def test_dtype(func, dtype, engine):
     assert actual.dtype == np.dtype("float64")


-@requires_dask
-def test_subset_blocks():
-    array = dask.array.random.random((120,), chunks=(4,))
-
-    blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27)
-    subset = subset_to_blocks(array, blockid)
-    assert subset.blocks.shape == (len(blockid),)
-
-
 @requires_dask
 @pytest.mark.parametrize(
     "flatblocks, expected",
@@ -1573,19 +1564,29 @@ def test_normalize_block_indexing_2d(flatblocks, expected):
     assert_equal_tuple(expected, actual)


+@requires_dask
+def test_subset_blocks():
+    array = dask.array.random.random((120,), chunks=(4,))
+
+    blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27)
+    subset = subset_to_blocks(array, blockid).to_array(array)
+    assert subset.blocks.shape == (len(blockid),)
+
+
+@pytest.mark.skip("temporarily removed this optimization")
 @requires_dask
 def test_subset_block_passthrough():
     from flox.core import identity

     # full slice pass through
     array = dask.array.ones((5,), chunks=(1,))
     expected = dask.array.map_blocks(identity, array)
-    subset = subset_to_blocks(array, np.arange(5))
+    subset = subset_to_blocks(array, np.arange(5)).to_array(array)
     assert subset.name == expected.name

     array = dask.array.ones((5, 5), chunks=1)
     expected = dask.array.map_blocks(identity, array)
-    subset = subset_to_blocks(array, np.arange(25))
+    subset = subset_to_blocks(array, np.arange(25)).to_array(array)
     assert subset.name == expected.name


@@ -1604,7 +1605,7 @@ def test_subset_block_passthrough():
 )
 def test_subset_block_2d(flatblocks, expectidx):
     array = dask.array.from_array(np.arange(25).reshape((5, 5)), chunks=1)
-    subset = subset_to_blocks(array, flatblocks)
+    subset = subset_to_blocks(array, flatblocks).to_array(array)
     assert len(subset.dask.layers) == 2
     assert_equal(subset, array.compute()[expectidx])
```