
Commit 6897240

Major fix to subset_to_blocks (#173)

1. Copy and extend `dask.array.blocks.__getitem__` to support orthogonal indexing. This means each cohort is a single layer in the graph.
2. Significantly extend the cohorts.py benchmarks.

1 parent 0bf35e0 commit 6897240
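For context: the "cohorts" strategy extracts, for each cohort of groups, the subset of blocks containing those groups. The fix selects all of a cohort's blocks in one indexing operation by unraveling the flat block indices per axis and combining the per-axis index lists with np.ix_, which indexes orthogonally. A minimal NumPy sketch of that selection (key_array is an illustrative stand-in for dask's internal grid of block keys; it is not code from this commit):

    import numpy as np

    # Stand-in for the block-key grid of a 5x5-block dask array; each entry
    # represents a key like ("array-token", i, j).
    key_array = np.arange(25).reshape(5, 5)

    # Flat block indices belonging to one cohort (same case as a test below).
    flatblocks = (0, 3, 4, 5, 6, 8, 24)

    # Unravel to per-axis block indices and deduplicate along each axis.
    rows, cols = (np.unique(idx) for idx in np.unravel_index(flatblocks, key_array.shape))

    # np.ix_ builds an orthogonal (outer-product) indexer, so a single indexing
    # operation pulls out the whole rows-by-cols subgrid of keys: one graph
    # layer, rather than one layer per axis of fancy indexing.
    subgrid = key_array[np.ix_(rows, cols)]
    print(subgrid.shape)  # (3, 4): blocks [0, 1, 4] x [0, 1, 3, 4]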

File tree

4 files changed: +239 -38 lines

asv_bench/benchmarks/cohorts.py (+72 -5)

@@ -29,7 +29,22 @@ def track_num_tasks(self):
         )[0]
         return len(result.dask.to_dict())
 
+    def track_num_tasks_optimized(self):
+        result = flox.groupby_reduce(
+            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
+        )[0]
+        (opt,) = dask.optimize(result)
+        return len(opt.dask.to_dict())
+
+    def track_num_layers(self):
+        result = flox.groupby_reduce(
+            self.array, self.by, func="sum", axis=self.axis, method="cohorts"
+        )[0]
+        return len(result.dask.layers)
+
     track_num_tasks.unit = "tasks"
+    track_num_tasks_optimized.unit = "tasks"
+    track_num_layers.unit = "layers"
 
 
 class NWMMidwest(Cohorts):
@@ -45,16 +60,68 @@ def setup(self, *args, **kwargs):
         self.axis = (-2, -1)
 
 
-class ERA5(Cohorts):
+class ERA5Dataset:
     """ERA5"""
 
+    def __init__(self, *args, **kwargs):
+        self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+        self.axis = (-1,)
+        self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))
+
+    def rechunk(self):
+        self.array = flox.core.rechunk_for_cohorts(
+            self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
+        )
+
+
+class ERA5DayOfYear(ERA5Dataset, Cohorts):
+    def setup(self, *args, **kwargs):
+        super().__init__()
+        self.by = self.time.dt.dayofyear.values
+
+
+class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
+    def setup(self, *args, **kwargs):
+        super().setup()
+        super().rechunk()
+
+
+class ERA5MonthHour(ERA5Dataset, Cohorts):
     def setup(self, *args, **kwargs):
-        time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+        super().__init__()
+        by = (self.time.dt.month.values, self.time.dt.hour.values)
+        ret = flox.core._factorize_multiple(
+            by,
+            expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
+            by_is_dask=False,
+            reindex=False,
+        )
+        # Add one so the rechunk code is simpler and makes sense
+        self.by = ret[0][0] + 1
 
-        self.by = time.dt.dayofyear.values
+
+class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
+    def setup(self, *args, **kwargs):
+        super().setup()
+        super().rechunk()
+
+
+class PerfectMonthly(Cohorts):
+    """Perfectly chunked for a "cohorts" monthly mean climatology"""
+
+    def setup(self, *args, **kwargs):
+        self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
         self.axis = (-1,)
+        self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
+        self.by = self.time.dt.month.values
 
-        array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
+    def rechunk(self):
         self.array = flox.core.rechunk_for_cohorts(
-            array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
+            self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True
         )
+
+
+class PerfectMonthlyRechunked(PerfectMonthly):
+    def setup(self, *args, **kwargs):
+        super().setup()
+        super().rechunk()
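These benchmark classes rely on asv's convention that methods named `track_*` return a number to record, labeled by the `unit` attribute set above. A sketch of exercising one of them by hand, outside the asv runner (assumes the repository root is on `sys.path`; everything is lazy, so nothing large is computed):

    from asv_bench.benchmarks.cohorts import ERA5DayOfYear

    bench = ERA5DayOfYear()
    bench.setup()
    # Graph size for method="cohorts" before and after dask's optimization
    # pass, plus the layer count this commit is designed to shrink.
    print(bench.track_num_tasks(), "tasks")
    print(bench.track_num_tasks_optimized(), "optimized tasks")
    print(bench.track_num_layers(), "layers")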

flox/core.py (+64 -32)

@@ -6,6 +6,7 @@
 import operator
 from collections import namedtuple
 from functools import partial, reduce
+from numbers import Integral
 from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union
 
 import numpy as np
@@ -288,7 +289,7 @@ def rechunk_for_cohorts(
     divisions = []
     counter = 1
     for idx, lab in enumerate(labels):
-        if lab in force_new_chunk_at:
+        if lab in force_new_chunk_at or idx == 0:
             divisions.append(idx)
             counter = 1
             continue
@@ -305,6 +306,7 @@ def rechunk_for_cohorts(
             divisions.append(idx)
             counter = 1
             continue
+
         counter += 1
 
     divisions.append(len(labels))
@@ -313,6 +315,9 @@ def rechunk_for_cohorts(
         print(labels_at_breaks[:40])
 
     newchunks = tuple(np.diff(divisions))
+    if debug:
+        print(divisions[:10], newchunks[:10])
+        print(divisions[-10:], newchunks[-10:])
     assert sum(newchunks) == len(labels)
 
     if newchunks == array.chunks[axis]:
@@ -1046,26 +1051,18 @@ def _reduce_blockwise(
     return result
 
 
-def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
-) -> DaskArray:
+def _normalize_indexes(array, flatblocks, blkshape):
     """
-    Advanced indexing of .blocks such that we always get a regular array back.
+    The .blocks accessor can only accept one iterable at a time,
+    but can handle multiple slices.
+    To minimize tasks and layers, we normalize to produce slices
+    along as many axes as possible, and then combine any remaining
+    iterables into a single orthogonal (np.ix_-style) indexer.
 
-    Parameters
-    ----------
-    array : dask.array
-    flatblocks : flat indices of blocks to extract
-    blkshape : shape of blocks with which to unravel flatblocks
-
-    Returns
-    -------
-    dask.array
+    TODO: move this upstream
     """
-    if blkshape is None:
-        blkshape = array.blocks.shape
-
     unraveled = np.unravel_index(flatblocks, blkshape)
+
     normalized: list[Union[int, np.ndarray, slice]] = []
     for ax, idx in enumerate(unraveled):
         i = _unique(idx).squeeze()
@@ -1077,30 +1074,65 @@ def subset_to_blocks(
             elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
                 normalized.append(slice(i[0], i[-1] + 1))
             else:
-                normalized.append(i)
+                normalized.append(list(i))
     full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)
 
     # has no iterables
-    noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
+    noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
     # has all iterables
-    alliter = {
-        ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
-    }
+    alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")}
 
-    # apply everything but the iterables
-    if all(i == slice(None) for i in noiter):
+    mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))
+
+    full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter))
+
+    return full_tuple
+
+
+def subset_to_blocks(
+    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+) -> DaskArray:
+    """
+    Advanced indexing of .blocks such that we always get a regular array back.
+
+    Parameters
+    ----------
+    array : dask.array
+    flatblocks : flat indices of blocks to extract
+    blkshape : shape of blocks with which to unravel flatblocks
+
+    Returns
+    -------
+    dask.array
+    """
+    import dask.array
+    from dask.array.slicing import normalize_index
+    from dask.base import tokenize
+    from dask.highlevelgraph import HighLevelGraph
+
+    if blkshape is None:
+        blkshape = array.blocks.shape
+
+    index = _normalize_indexes(array, flatblocks, blkshape)
+
+    if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
         return array
 
-    subset = array.blocks[noiter]
+    # The rest is copied from dask.array.core with slight modifications
+    index = normalize_index(index, array.numblocks)
+    index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
 
-    for ax, inds in alliter.items():
-        if isinstance(inds, slice):
-            continue
-        idxr = [slice(None, None)] * array.ndim
-        idxr[ax] = inds
-        subset = subset.blocks[tuple(idxr)]
+    name = "blocks-" + tokenize(array, index)
+    new_keys = array._key_array[index]
+
+    squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
+    chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
+
+    keys = itertools.product(*(range(len(c)) for c in chunks))
+    layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
 
-    return subset
+    return dask.array.Array(graph, name, chunks, meta=array)
 
 
 def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
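To make the new code paths concrete, here is a small sketch of `_normalize_indexes` and `subset_to_blocks` on a 5x5-block array (these are private flox functions; the sketch assumes the flox version in this commit). Block indices that cannot be expressed as slices come back as an np.ix_-style mesh, and the subset carries exactly one new layer:

    import dask.array
    import numpy as np

    from flox.core import _normalize_indexes, subset_to_blocks

    array = dask.array.ones((5, 5), chunks=1)

    # Flat blocks 0, 12, 14 are (0, 0), (2, 2), (2, 4): rows {0, 2} and
    # columns {0, 2, 4} are not contiguous, so no slice normalization applies.
    index = _normalize_indexes(array, [0, 12, 14], array.blocks.shape)
    print(index)  # (array([[0], [2]]), array([[0, 2, 4]]))

    # One "blocks-" layer on top of the source layer, instead of one layer
    # per axis as with chained .blocks indexing.
    subset = subset_to_blocks(array, [0, 12, 14])
    print(subset.numblocks)         # (2, 3)
    print(len(subset.dask.layers))  # 2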

tests/__init__.py (+12)

@@ -115,6 +115,18 @@ def assert_equal(a, b, tolerance=None):
         np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)
 
 
+def assert_equal_tuple(a, b):
+    """assert_equal for .blocks indexing tuples"""
+    assert len(a) == len(b)
+
+    for a_, b_ in zip(a, b):
+        assert type(a_) == type(b_)
+        if isinstance(a_, np.ndarray):
+            np.testing.assert_array_equal(a_, b_)
+        else:
+            assert a_ == b_
+
+
 @pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
 def engine(request):
     if request.param == "numba":
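For reference, a quick usage sketch of the new helper (import path per this repository's test package):

    import numpy as np

    from tests import assert_equal_tuple

    # Passes: slices compare with ==, arrays with assert_array_equal.
    assert_equal_tuple((slice(1, 4), np.array([1, 3])), (slice(1, 4), np.array([1, 3])))

    # Would raise AssertionError: types differ (list vs np.ndarray).
    # assert_equal_tuple(([1, 3],), (np.array([1, 3]),))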

tests/test_core.py (+91 -1)

@@ -12,14 +12,23 @@
 from flox.core import (
     _convert_expected_groups_to_index,
     _get_optimal_chunks_for_groups,
+    _normalize_indexes,
     factorize_,
     find_group_cohorts,
     groupby_reduce,
     rechunk_for_cohorts,
     reindex_,
+    subset_to_blocks,
 )
 
-from . import assert_equal, engine, has_dask, raise_if_dask_computes, requires_dask
+from . import (
+    assert_equal,
+    assert_equal_tuple,
+    engine,
+    has_dask,
+    raise_if_dask_computes,
+    requires_dask,
+)
 
 labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
 nan_labels = labels.astype(float)  # copy
@@ -1035,3 +1044,84 @@ def test_dtype(func, dtype, engine):
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
     actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
     assert actual.dtype == np.dtype("float64")
+
+
+@requires_dask
+def test_subset_blocks():
+    array = dask.array.random.random((120,), chunks=(4,))
+
+    blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27)
+    subset = subset_to_blocks(array, blockid)
+    assert subset.blocks.shape == (len(blockid),)
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "flatblocks, expected",
+    (
+        ((0, 1, 2, 3, 4), (slice(None),)),
+        ((1, 2, 3), (slice(1, 4),)),
+        ((1, 3), ([1, 3],)),
+        ((0, 1, 3), ([0, 1, 3],)),
+    ),
+)
+def test_normalize_block_indexing_1d(flatblocks, expected):
+    nblocks = 5
+    array = dask.array.ones((nblocks,), chunks=(1,))
+    expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
+    actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
+    assert_equal_tuple(expected, actual)
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "flatblocks, expected",
+    (
+        ((0, 1, 2, 3, 4), (0, slice(None))),
+        ((1, 2, 3), (0, slice(1, 4))),
+        ((1, 3), (0, [1, 3])),
+        ((0, 1, 3), (0, [0, 1, 3])),
+        (tuple(range(10)), (slice(0, 2), slice(None))),
+        ((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])),
+        ((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])),
+    ),
+)
+def test_normalize_block_indexing_2d(flatblocks, expected):
+    nblocks = 5
+    ndim = 2
+    array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim)
+    expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
+    actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
+    assert_equal_tuple(expected, actual)
+
+
+@requires_dask
+def test_subset_block_passthrough():
+    # full slice pass through
+    array = dask.array.ones((5,), chunks=(1,))
+    subset = subset_to_blocks(array, np.arange(5))
+    assert subset.name == array.name
+
+    array = dask.array.ones((5, 5), chunks=1)
+    subset = subset_to_blocks(array, np.arange(25))
+    assert subset.name == array.name
+
+
+@requires_dask
+@pytest.mark.parametrize(
+    "flatblocks, expectidx",
+    [
+        (np.arange(10), (slice(2), slice(None))),
+        (np.arange(8), (slice(2), slice(None))),
+        ([0, 10], ([0, 2], slice(1))),
+        ([0, 7], (slice(2), [0, 2])),
+        ([0, 7, 9], (slice(2), [0, 2, 4])),
+        ([0, 6, 12, 14], (slice(3), [0, 1, 2, 4])),
+        ([0, 12, 14, 19], np.ix_([0, 2, 3], [0, 2, 4])),
+    ],
+)
+def test_subset_block_2d(flatblocks, expectidx):
+    array = dask.array.from_array(np.arange(25).reshape((5, 5)), chunks=1)
+    subset = subset_to_blocks(array, flatblocks)
+    assert len(subset.dask.layers) == 2
+    assert_equal(subset, array.compute()[expectidx])
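Working through the last parametrized case by hand shows the orthogonal-indexing trade-off: because the result must be a regular block grid, the subset can include blocks beyond those requested. A plain NumPy sketch mirroring the test data:

    import numpy as np

    # Flat blocks [0, 12, 14, 19] on a 5x5 grid unravel to
    # (0, 0), (2, 2), (2, 4), (3, 4).
    rows, cols = np.unravel_index([0, 12, 14, 19], (5, 5))
    print(sorted(set(rows)), sorted(set(cols)))  # [0, 2, 3] [0, 2, 4]

    # Neither axis forms a contiguous run, so the expected indexer is the
    # orthogonal product np.ix_([0, 2, 3], [0, 2, 4]): 9 blocks, a superset
    # of the 4 requested, but a regular (rectangular) grid.
    data = np.arange(25).reshape(5, 5)
    print(data[np.ix_([0, 2, 3], [0, 2, 4])])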
