Commit 4952fe9

method heuristics: Avoid dot product as much as possible (#347)
* Another `method` detection optimization
* fix
* silence warnings
* silence one more warning
* Even better shortcut
* Update docs
1 parent 307899a commit 4952fe9

4 files changed, +49 -19 lines


asv_bench/benchmarks/cohorts.py

+12 -3

@@ -95,7 +95,7 @@ class ERA5Dataset:
     """ERA5"""
 
     def __init__(self, *args, **kwargs):
-        self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+        self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="h"))
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))
 
@@ -143,7 +143,7 @@ class PerfectMonthly(Cohorts):
     """Perfectly chunked for a "cohorts" monthly mean climatology"""
 
     def setup(self, *args, **kwargs):
-        self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
+        self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="ME"))
         self.axis = (-1,)
         self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
         self.by = self.time.dt.month.values - 1
@@ -164,7 +164,7 @@ def rechunk(self):
 class ERA5Google(Cohorts):
     def setup(self, *args, **kwargs):
         TIME = 900  # 92044 in Google ARCO ERA5
-        self.time = pd.Series(pd.date_range("1959-01-01", freq="6H", periods=TIME))
+        self.time = pd.Series(pd.date_range("1959-01-01", freq="6h", periods=TIME))
         self.axis = (2,)
         self.array = dask.array.ones((721, 1440, TIME), chunks=(-1, -1, 1))
         self.by = self.time.dt.day.values - 1
@@ -201,3 +201,12 @@ def setup(self, *args, **kwargs):
         self.time = pd.Series(index)
         self.by = self.time.dt.dayofyear.values - 1
         self.expected = pd.RangeIndex(self.by.max() + 1)
+
+
+class RandomBigArray(Cohorts):
+    def setup(self, *args, **kwargs):
+        M, N = 100_000, 20_000
+        self.array = dask.array.random.normal(size=(M, N), chunks=(10_000, N // 5)).T
+        self.by = np.random.choice(5_000, size=M)
+        self.expected = pd.RangeIndex(5000)
+        self.axis = (1,)
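
The frequency-alias changes above ("H" to "h", "6H" to "6h", "M" to "ME") follow pandas' switch to lowercase hour aliases and the "ME" month-end spelling, which is presumably what the "silence warnings" items in the commit message refer to. A minimal sketch of the new spellings, assuming pandas 2.2 or newer:

import pandas as pd

# Hourly and 6-hourly ranges now use the lowercase "h" alias.
hourly = pd.date_range("2016-01-01", "2016-01-02", freq="h")
six_hourly = pd.date_range("1959-01-01", periods=8, freq="6h")

# Month-end frequency is spelled "ME"; the old "M" alias is deprecated.
monthly = pd.date_range("1961-01-01", "1962-01-01", freq="ME")

print(len(hourly), len(six_hourly), len(monthly))  # 25 8 12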

ci/benchmark.yml

+1

@@ -3,6 +3,7 @@ channels:
   - conda-forge
 dependencies:
   - asv
+  - build
   - cachey
   - dask-core
   - numpy>=1.22

docs/source/implementation.md

+4 -2

@@ -300,8 +300,10 @@ label overlaps with all other labels. The algorithm is as follows.
 ![cohorts-schematic](/../diagrams/containment.png)
 
 1. To choose between `"map-reduce"` and `"cohorts"`, we need a summary measure of the degree to which the labels overlap with
-   each other. We use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
-   When sparsity > 0.6, we choose `"map-reduce"` since there is decent overlap between (any) cohorts. Otherwise we use `"cohorts"`.
+   each other. We can use _sparsity_ --- the number of non-zero elements in `C` divided by the number of elements in `C`, `C.nnz/C.size`.
+   We use sparsity(`S`) as an approximation for the sparsity(`C`) to avoid a potentially expensive sparse matrix dot product when `S`
+   isn't particularly sparse. When sparsity(`S`) > 0.4 (arbitrary), we choose `"map-reduce"` since there is decent overlap between
+   (any) cohorts. Otherwise we use `"cohorts"`.
 
 Cool, isn't it?!
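
To make the sparsity measure concrete, here is a small illustrative sketch (not part of the commit; the toy bitmasks are invented) comparing the sparsity of a bitmask whose labels are perfectly cohorted against one whose labels appear in every chunk:

import math
import numpy as np
from scipy.sparse import csr_array

# Toy bitmask S: rows are chunks, columns are group labels;
# S[i, j] is True when label j occurs in chunk i.

# Each label confined to a single chunk: perfect cohorts.
S_cohorts = csr_array(np.eye(4, dtype=bool))
print(S_cohorts.nnz / math.prod(S_cohorts.shape))  # 0.25 -> below 0.4, prefer "cohorts"

# Every label present in every chunk: complete overlap.
S_overlap = csr_array(np.ones((4, 4), dtype=bool))
print(S_overlap.nnz / math.prod(S_overlap.shape))  # 1.0 -> above 0.4, prefer "map-reduce"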

flox/core.py

+32 -14

@@ -363,37 +363,55 @@ def invert(x) -> tuple[np.ndarray, ...]:
         logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
         return "cohorts", chunks_cohorts
 
-    # Containment = |Q & S| / |Q|
+    # We'll use containment to measure degree of overlap between labels.
+    # Containment C = |Q & S| / |Q|
     # - |X| is the cardinality of set X
     # - Q is the query set being tested
     # - S is the existing set
-    # We'll use containment to measure degree of overlap between labels. The bitmask
-    # matrix allows us to calculate this pretty efficiently.
-    asfloat = bitmask.astype(float)
-    # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
-    # makes it non-symmetric.
-    containment = csr_array((asfloat.T @ asfloat) / chunks_per_label)
-
-    # The containment matrix is a measure of how much the labels overlap
-    # with each other. We treat the sparsity = (nnz/size) as a summary measure of the net overlap.
+    # The bitmask matrix S allows us to calculate this pretty efficiently using a dot product.
+    #     S.T @ S / chunks_per_label
+    #
+    # We treat the sparsity(C) = (nnz/size) as a summary measure of the net overlap.
     # 1. For high enough sparsity, there is a lot of overlap and we should use "map-reduce".
     # 2. When labels are uniformly distributed amongst all chunks
     #    (and number of labels < chunk size), sparsity is 1.
     # 3. Time grouping cohorts (e.g. dayofyear) appear as lines in this matrix.
     # 4. When there are no overlaps at all between labels, containment is a block diagonal matrix
     #    (approximately).
-    MAX_SPARSITY_FOR_COHORTS = 0.6  # arbitrary
-    sparsity = containment.nnz / math.prod(containment.shape)
+    #
+    # However computing S.T @ S can still be the slowest step, especially if S
+    # is not particularly sparse. Empirically the sparsity( S.T @ S ) > min(1, 2 x sparsity(S)).
+    # So we use sparsity(S) as a shortcut.
+    MAX_SPARSITY_FOR_COHORTS = 0.4  # arbitrary
+    sparsity = bitmask.nnz / math.prod(bitmask.shape)
     preferred_method: Literal["map-reduce"] | Literal["cohorts"]
+    logger.debug(
+        "sparsity of bitmask is {}, threshold is {}".format(  # noqa
+            sparsity, MAX_SPARSITY_FOR_COHORTS
+        )
+    )
     if sparsity > MAX_SPARSITY_FOR_COHORTS:
-        logger.info("sparsity is {}".format(sparsity))  # noqa
         if not merge:
-            logger.info("find_group_cohorts: merge=False, choosing 'map-reduce'")
+            logger.info(
+                "find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'".format(  # noqa
+                    sparsity
+                )
+            )
             return "map-reduce", {}
         preferred_method = "map-reduce"
     else:
         preferred_method = "cohorts"
 
+    # Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
+    # makes it non-symmetric.
+    asfloat = bitmask.astype(float)
+    containment = csr_array(asfloat.T @ asfloat / chunks_per_label)
+
+    logger.debug(
+        "sparsity of containment matrix is {}".format(  # noqa
+            containment.nnz / math.prod(containment.shape)
+        )
+    )
     # Use a threshold to force some merging. We do not use the filtered
     # containment matrix for estimating "sparsity" because it is a bit
     # hard to reason about.
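
The reordering above is the heart of the change: the cheap bitmask-sparsity check now runs before, and can skip, the S.T @ S product. A standalone sketch of that idea, with invented names (`choose_method` is not a flox function) and a toy bitmask:

import math
import numpy as np
from scipy.sparse import csr_array

MAX_SPARSITY_FOR_COHORTS = 0.4  # same arbitrary threshold as the commit

def choose_method(bitmask: csr_array, chunks_per_label: np.ndarray) -> str:
    # Cheap check first: sparsity of the bitmask S itself.
    sparsity = bitmask.nnz / math.prod(bitmask.shape)
    if sparsity > MAX_SPARSITY_FOR_COHORTS:
        # Labels overlap heavily; skip the dot product entirely.
        return "map-reduce"
    # Only now pay for the containment matrix C = S.T @ S / chunks_per_label.
    asfloat = bitmask.astype(float)
    containment = csr_array(asfloat.T @ asfloat / chunks_per_label)
    # ... cohort detection would continue from `containment` here ...
    return "cohorts"

# Toy example: 6 chunks, 3 labels, each label confined to two adjacent chunks.
S = csr_array(np.repeat(np.eye(3), 2, axis=0))
chunks_per_label = np.asarray(S.sum(axis=0)).ravel()  # [2. 2. 2.]
print(choose_method(S, chunks_per_label))  # "cohorts"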
