
Commit 561378d

Add docs on method=None heuristics (#328)

* Update docs for `method=None` and spatial cohorts grouping. Closes #325
* Add docs on heuristics
* tweaks

1 parent: 3e0653f

6 files changed: +105 −29 lines
docs/diagrams/containment.png (12.8 KB)

docs/diagrams/nwm-cohorts.png (149 KB)

docs/source/implementation.md (+98 −18)
````diff
@@ -1,3 +1,12 @@
+---
+jupytext:
+  text_representation:
+    format_name: myst
+kernelspec:
+  display_name: Python 3
+  name: python3
+---
+
 (algorithms)=

 # Parallel Algorithms
````
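This frontmatter registers the page as a jupytext/MyST notebook, so the `{code-cell}` blocks introduced later in this diff are executed when the docs are built.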
````diff
@@ -7,10 +16,14 @@
 can be hard. Performance strongly depends on how the groups are distributed amongst the blocks of an array.

 `flox` implements 4 strategies for grouped reductions, each appropriate for a particular distribution of groups
-among the blocks of a dask array. Switch between the various strategies by passing `method`
-and/or `reindex` to either {py:func}`flox.groupby_reduce` or {py:func}`flox.xarray.xarray_reduce`.
+among the blocks of a dask array.
+
+```{tip}
+By default, `flox >= 0.9.0` will use [heuristics](method-heuristics) to choose a `method`.
+```

-Your options are:
+Switch between the various strategies by passing `method` and/or `reindex` to either {py:func}`flox.groupby_reduce`
+or {py:func}`flox.xarray.xarray_reduce`. Your options are:

 1. [`method="map-reduce"` with `reindex=False`](map-reindex-false)
 1. [`method="map-reduce"` with `reindex=True`](map-reindex-True)
````
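For concreteness, here is a minimal sketch (not part of this commit) of passing `method` through `flox.groupby_reduce`; the array size and chunking are illustrative assumptions.

```python
# Minimal sketch (not part of this commit): forcing a strategy via `method`.
# The array size and chunking are illustrative.
import dask.array as da
import numpy as np

import flox

labels = np.tile(np.arange(12), 12)                 # e.g. month-of-year labels
array = da.random.random((labels.size,), chunks=4)  # small chunks along time

# Let the heuristics pick (flox >= 0.9), or request a strategy explicitly:
result, groups = flox.groupby_reduce(array, labels, func="mean")
result, groups = flox.groupby_reduce(array, labels, func="mean", method="cohorts")
```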
````diff
@@ -20,18 +33,17 @@ Your options are:
 The most appropriate strategy for your problem will depend on the chunking of your dataset,
 and the distribution of group labels across those chunks.

-```{tip}
 Currently these strategies are implemented for dask. We would like to generalize to other parallel array types
 as appropriate (e.g. Ramba, cubed, arkouda). Please open an issue to discuss if you are interested.
-```

 (xarray-split)=

-## Background: Xarray's current GroupBy strategy
+## Background

-Xarray's current strategy is to find all unique group labels, index out each group,
-and then apply the reduction operation. Note that this only works if we know the group
-labels (i.e. you cannot use this strategy to group by a dask array).
+Without `flox` installed, Xarray's GroupBy strategy is to find all unique group labels,
+index out each group, and then apply the reduction operation. Note that this only works
+if we know the group labels (i.e. you cannot use this strategy to group by a dask array),
+and is basically an unvectorized, slow for-loop over groups.

 Schematically, this looks like (colors indicate group labels; separated groups of colors
 indicate different blocks of an array):
````
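The strategy described in this hunk boils down to a loop like the following sketch (illustrative only; not Xarray's actual code):

```python
# Sketch of the strategy described above: find the unique labels, index out
# each group, then reduce. Note the slow, unvectorized loop over groups.
import numpy as np

array = np.random.random((120,))
by = np.tile(np.arange(12), 10)  # group labels, known ahead of time

result = {label: array[by == label].mean() for label in np.unique(by)}
```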
````diff
@@ -208,23 +220,91 @@ One annoyance is that if the chunksize doesn't evenly divide the number of groups
 Consider our earlier example, `groupby("time.month")` with monthly frequency data and chunksize of 4 along `time`.
 ![cohorts-schematic](/../diagrams/cohorts-month-chunk4.png)

+```{code-cell}
+import flox
+import numpy as np
+
+labels = np.tile(np.arange(12), 12)
+chunks = (tuple(np.repeat(4, labels.size // 4)),)
+```
+
 `flox` can find these cohorts; below, it identifies the cohorts with labels `1,2,3,4`; `5,6,7,8`; and `9,10,11,12`.

-```python
->>> flox.find_group_cohorts(labels, array.chunks[-1]).values()
-[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] # 3 cohorts
+```{code-cell}
+preferred_method, chunks_cohorts = flox.core.find_group_cohorts(labels, chunks)
+chunks_cohorts.values()
 ```

 Now consider `chunksize=5`.
 ![cohorts-schematic](/../diagrams/cohorts-month-chunk5.png)

-```python
->>> flox.core.find_group_cohorts(labels, array.chunks[-1]).values()
-[[1], [2, 3], [4, 5], [6], [7, 8], [9, 10], [11], [12]] # 8 cohorts
+```{code-cell}
+labels = np.tile(np.arange(12), 12)
+chunks = (tuple(np.repeat(5, labels.size // 5)) + (4,),)
+preferred_method, chunks_cohorts = flox.core.find_group_cohorts(labels, chunks, merge=True)
+chunks_cohorts.values()
 ```

-We find 8 cohorts (note the original xarray strategy is equivalent to constructing 12 cohorts).
-In this case, it seems better to rechunk to a size of `4` along `time`.
-If you have ideas for improving this case, please open an issue.
+We find 7 cohorts (note the original xarray strategy is equivalent to constructing 12 cohorts).
+In this case, it seems better to rechunk to a size of `4` (or `6`) along `time`.
+
+Indeed, flox's heuristics think `"map-reduce"` is better for this case:
+
+```{code-cell}
+preferred_method
+```

 ### Example: spatial grouping
+
+Spatial groupings are particularly interesting for the `"cohorts"` strategy. Consider the problem of computing county-level
+aggregated statistics ([example blog post](https://xarray.dev/blog/flox)). There are ~3100 groups (counties), each marked by
+a different color. There are ~2300 chunks of size (350, 350) in (lat, lon). Many groups are confined to a small number of chunks:
+see the left panel, where the grid lines mark chunk boundaries.
+
+![cohorts-schematic](/../diagrams/nwm-cohorts.png)
+
+This seems like a good fit for `"cohorts"`: to get the answer for a county in the Northwest US, we needn't look at values
+for the Southwest US. How do we decide that automatically for the user?
+
+(method-heuristics)=
+
+## Heuristics
+
+`flox >= 0.9` will automatically choose `method` for you. To do so, we need to detect how each group
+label is distributed across the chunks of the array, and the degree to which the chunk distribution for a particular
+label overlaps with those of all other labels. The algorithm is as follows.
+
+1. First determine which labels are present in each chunk. The distribution of labels across chunks
+   is represented internally as a 2D boolean sparse array `S[chunks, labels]`. `S[i, j] = 1` when
+   label `j` is present in chunk `i`.
+
+1. Then we look for patterns in `S` to decide if we can use `"blockwise"`. The dark cells in the images below are `1` at that
+   position in `S`.
+   ![bitmask-patterns](/../diagrams/bitmask-patterns-perfect.png)
+
+   - On the left is a monthly grouping for a monthly time series with chunk size 4. There are 3 non-overlapping cohorts, so
+     `method="cohorts"` is perfect.
+   - On the right is a resampling problem of a daily time series with chunk size 10 to 5-daily frequency. Two 5-day periods
+     are exactly contained in one chunk, so `method="blockwise"` is perfect.
+
+1. The metric used for determining the degree of overlap between the chunks occupied by different labels is
+   [containment](http://ekzhu.com/datasketch/lshensemble.html). For each label `i` we can quickly compute containment against
+   all other labels `j` as `C = S.T @ S / number_chunks_per_label`. Here is `C` for a range of chunk sizes from 1 to 12, for computing
+   the monthly mean of a monthly time series problem [the title on each image is `(chunk size, sparsity)`].
+
+   ```python
+   chunks = np.arange(1, 13)
+   labels = np.tile(np.arange(1, 13), 30)
+   ```
+
+   ![cohorts-schematic](/../diagrams/containment.png)
+
+1. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with
+   each other. We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
+   When sparsity > 0.6, we choose `"map-reduce"` since there is decent overlap between (any) cohorts. Otherwise we use `"cohorts"`.
+
+Cool, isn't it?!
+
+For reference, here are `S` and `C` for the US county groupby problem:
+![county-bitmask](/../diagrams/counties-bitmask-containment.png)
+The sparsity of `C` is 0.006, so `"cohorts"` seems a good strategy here.
````
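The heuristic documented above can be sketched with dense NumPy arrays. flox's actual implementation is `flox.core.find_group_cohorts` and uses sparse arrays; the variable names below are illustrative.

```python
# Illustrative sketch of the heuristic above; flox's real implementation
# lives in flox.core.find_group_cohorts and uses sparse arrays.
import numpy as np

labels = np.tile(np.arange(12), 12)      # monthly labels for 12 years
chunks = np.repeat(4, labels.size // 4)  # chunk size 4 along time

# Step 1: bitmask S[chunks, labels]; S[i, j] = 1 when label j is in chunk i.
nlabels = labels.max() + 1
S = np.zeros((len(chunks), nlabels), dtype=int)
starts = np.concatenate([[0], np.cumsum(chunks)[:-1]])
for i, (start, size) in enumerate(zip(starts, chunks)):
    S[i, np.unique(labels[start : start + size])] = 1

# Step 3: containment C = S.T @ S / number_chunks_per_label.
chunks_per_label = S.sum(axis=0)
C = (S.T @ S) / chunks_per_label[np.newaxis, :]

# Step 4: the sparsity of C decides the method.
sparsity = (C > 0).sum() / C.size
method = "map-reduce" if sparsity > 0.6 else "cohorts"
print(sparsity, method)  # ~0.33 -> "cohorts" for this chunking
```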

flox/visualize.py (+7 −11)
````diff
@@ -139,35 +139,31 @@ def visualize_cohorts_2d(by, chunks):
     assert by.ndim == 2
     print("finding cohorts...")
     chunks = [chunks[ax] for ax in range(-by.ndim, 0)]
-    before_merged = find_group_cohorts(by, chunks, merge=False)
-    merged = find_group_cohorts(by, chunks, merge=True)
+    _, chunks_cohorts = find_group_cohorts(by, chunks)
     print("finished cohorts...")

     xticks = np.cumsum(chunks[-1])
     yticks = np.cumsum(chunks[-2])

-    f, ax = plt.subplots(1, 3, constrained_layout=True, sharex=False, sharey=False)
+    f, ax = plt.subplots(1, 2, constrained_layout=True, sharex=False, sharey=False)
     ax = ax.ravel()
     # ax[1].set_visible(False)
     # ax = ax[[0, 2, 3]]

     ngroups = len(_unique(by))
     h0 = ax[0].imshow(by, vmin=0, cmap=get_colormap(ngroups))
-    h1 = _visualize_cohorts(chunks, before_merged, ax=ax[1])
-    h2 = _visualize_cohorts(chunks, merged, ax=ax[2])
+    h2 = _visualize_cohorts(chunks, chunks_cohorts, ax=ax[1])

-    for axx in ax:
-        axx.grid(True, which="both")
+    ax[0].grid(True, which="both")
     for axx in ax[:1]:
         axx.set_xticks(xticks)
         axx.set_yticks(yticks)
-    for h, axx in zip([h0, h1, h2], ax):
+    for h, axx in zip([h0, h2], ax):
         f.colorbar(h, ax=axx, orientation="horizontal")

     ax[0].set_title(f"by: {ngroups} groups")
-    ax[1].set_title(f"{len(before_merged)} cohorts")
-    ax[2].set_title(f"{len(merged)} merged cohorts")
-    f.set_size_inches((12, 6))
+    ax[1].set_title(f"{len(chunks_cohorts)} cohorts")
+    f.set_size_inches((9, 6))


 def _visualize_cohorts(chunks, cohorts, ax=None):
````
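For context, a hypothetical usage sketch of the function being edited here; the 2D label array and dask-style `chunks` are made-up inputs, and matplotlib is required.

```python
# Hypothetical usage of visualize_cohorts_2d (assumed inputs; requires
# matplotlib). `chunks` uses dask's per-axis tuple-of-chunk-sizes layout.
import numpy as np
from flox.visualize import visualize_cohorts_2d

by = np.add.outer(np.arange(100) // 25, np.arange(100) // 25)  # 2D group labels
chunks = ((25,) * 4, (25,) * 4)  # chunk boundaries every 25 cells on each axis
visualize_cohorts_2d(by, chunks)
```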
