Commit 666d45e

Use set containment instead of perfect subsets (#291)
* Use set containment instead of perfect subsets

  xref #180

  Containment = |Q & S| / |Q|, where
  - |X| is the cardinality of set X
  - Q is the query set being tested
  - S is the existing set

  https://ekzhu.com/datasketch/lshensemble.html#containment
1 parent 769db63 commit 666d45e
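
For illustration only, here is a minimal Python sketch of that containment metric on plain sets; the containment() helper below is hypothetical and not part of the flox codebase:

    def containment(query: set, existing: set) -> float:
        # Containment = |Q & S| / |Q|, where Q is the query set and S the existing set.
        return len(query & existing) / len(query)

    # A perfect subset gives containment 1.0; partial overlap gives a fraction.
    assert containment({1, 2}, {1, 2, 3}) == 1.0         # Q fully inside S
    assert containment({1, 2, 3, 4}, {1, 2, 3}) == 0.75  # 3 of Q's 4 elements are in S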

File tree

4 files changed: +53 -27 lines changed

asv_bench/benchmarks/cohorts.py (+5 -1)

@@ -45,6 +45,10 @@ def track_num_layers(self):
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_layers.unit = "layers"  # type: ignore[attr-defined] # Lazy
+    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers]:
+        f.repeat = 1  # type: ignore[attr-defined] # Lazy
+        f.rounds = 1  # type: ignore[attr-defined] # Lazy
+        f.number = 1  # type: ignore[attr-defined] # Lazy
 
 
 class NWMMidwest(Cohorts):

@@ -83,7 +87,7 @@ def setup(self, *args, **kwargs):
 class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
     def setup(self, *args, **kwargs):
         super().setup()
-        super().rechunk()
+        self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 24))
 
 
 class ERA5MonthHour(ERA5Dataset, Cohorts):

asv_bench/benchmarks/reduce.py (+11 -11)

@@ -59,17 +59,17 @@ def time_reduce(self, func, expected_name, engine):
             expected_groups=expected_groups[expected_name],
         )
 
-    @skip_for_params(numbagg_skip)
-    @parameterize({"func": funcs, "expected_name": expected_names, "engine": engines})
-    def peakmem_reduce(self, func, expected_name, engine):
-        flox.groupby_reduce(
-            self.array,
-            self.labels,
-            func=func,
-            engine=engine,
-            axis=self.axis,
-            expected_groups=expected_groups[expected_name],
-        )
+    # @skip_for_params(numbagg_skip)
+    # @parameterize({"func": funcs, "expected_name": expected_names, "engine": engines})
+    # def peakmem_reduce(self, func, expected_name, engine):
+    #     flox.groupby_reduce(
+    #         self.array,
+    #         self.labels,
+    #         func=func,
+    #         engine=engine,
+    #         axis=self.axis,
+    #         expected_groups=expected_groups[expected_name],
+    #     )
 
 
 class ChunkReduce1D(ChunkReduce):

flox/core.py (+21 -14)

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import copy
 import itertools
 import math
 import operator

@@ -304,14 +303,15 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-
-    if not single_chunks and merge:
+    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
+    if not one_group_per_chunk and not single_chunks and merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
 
-        items = tuple((k, set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
+        # precompute needed metrics for the quadratic loop below.
+        items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
 
         merged_cohorts = {}
         merged_keys: set[tuple] = set()

@@ -320,21 +320,28 @@ def invert(x) -> tuple[np.ndarray, ...]:
         # and then merge in cohorts that are present in a subset of those chunks
         # I think this is suboptimal and must fail at some point.
         # But it might work for most cases. There must be a better way...
-        for idx, (k1, set_k1, v1) in enumerate(items):
+        for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
             if k1 in merged_keys:
                 continue
-            merged_cohorts[k1] = copy.deepcopy(v1)
-            for k2, set_k2, v2 in items[idx + 1 :]:
-                if k2 not in merged_keys and set_k2.issubset(set_k1):
-                    merged_cohorts[k1].extend(v2)
-                    merged_keys.update((k2,))
-
-        # make sure each cohort is sorted after merging
-        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+            new_key = set_k1
+            new_value = v1
+            # iterate in reverse since we expect small cohorts
+            # to be most likely merged in to larger ones
+            for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
+                if k2 not in merged_keys:
+                    if (len(set_k2 & new_key) / len_k2) > 0.75:
+                        new_key |= set_k2
+                        new_value += v2
+                        merged_keys.update((k2,))
+            sorted_ = sorted(new_value)
+            merged_cohorts[tuple(sorted(new_key))] = sorted_
+            if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
+                break
+
         # sort by first label in cohort
         # This will help when sort=True (default)
         # and we have to resort the dask array
-        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+        return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))
 
     else:
         return chunks_cohorts
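
For context, a simplified standalone sketch of the containment-based merge performed in the hunk above, assuming cohorts arrive as a mapping from chunk-index tuples to label lists; the 0.75 threshold mirrors the patch, but merge_cohorts() and its variable names are illustrative only:

    def merge_cohorts(chunks_cohorts, threshold=0.75):
        # Largest cohorts (spanning the most chunks) first.
        items = sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
        merged, merged_keys = {}, set()
        for idx, (k1, v1) in enumerate(items):
            if k1 in merged_keys:
                continue
            new_key, new_value = set(k1), list(v1)
            # Absorb any later cohort whose chunks are mostly contained in this one,
            # i.e. containment = |Q & S| / |Q| exceeds the threshold.
            for k2, v2 in items[idx + 1 :]:
                if k2 in merged_keys:
                    continue
                if len(set(k2) & new_key) / len(k2) > threshold:
                    new_key |= set(k2)
                    new_value += v2
                    merged_keys.add(k2)
            merged[tuple(sorted(new_key))] = sorted(new_value)
        return merged

    # The {3, 4} cohort is fully contained in {1, 2, 3, 4}, so its label merges in.
    print(merge_cohorts({(1, 2, 3, 4): [10, 20], (3, 4): [30], (5, 6): [40]}))
    # {(1, 2, 3, 4): [10, 20, 30], (5, 6): [40]}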

tests/test_core.py (+16 -1)

@@ -848,11 +848,26 @@ def test_rechunk_for_blockwise(inchunks, expected):
         ],
     ],
 )
-def test_find_group_cohorts(expected, labels, chunks, merge):
+def test_find_group_cohorts(expected, labels, chunks: tuple[int], merge: bool) -> None:
     actual = list(find_group_cohorts(labels, (chunks,), merge).values())
     assert actual == expected, (actual, expected)
 
 
+@pytest.mark.parametrize("chunksize", [12, 13, 14, 24, 36, 48, 72, 71])
+def test_verify_complex_cohorts(chunksize: int) -> None:
+    time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
+    chunks = (chunksize,) * (len(time) // chunksize)
+    by = np.array(time.dt.dayofyear.values)
+
+    if len(by) != sum(chunks):
+        chunks += (len(by) - sum(chunks),)
+    chunk_cohorts = find_group_cohorts(by - 1, (chunks,))
+    chunks_ = np.sort(np.concatenate(tuple(chunk_cohorts.keys())))
+    groups = np.sort(np.concatenate(tuple(chunk_cohorts.values())))
+    assert_equal(np.unique(chunks_), np.arange(len(chunks), dtype=int))
+    assert_equal(groups, np.arange(366, dtype=int))
+
+
 @requires_dask
 @pytest.mark.parametrize(
     "chunk_at,expected",

0 commit comments
