Skip to content

Commit f44738e

Browse files
authored
Fix race condition in factorize (#196)
* Fix race condition in factorize with multiple groupers, multiple reductions. * Add another reduction to see if that triggers it more * Fix * actually fix
1 parent e3fc5f0 commit f44738e

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

flox/core.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,12 @@ def factorize_(
446446
for groupvar, expect in zip(by, expected_groups):
447447
flat = groupvar.reshape(-1)
448448
if isinstance(expect, pd.RangeIndex):
449-
idx = flat
449+
# idx is a view of the original `by` array
450+
# copy here so we don't have a race condition with the
451+
# group_idx[nanmask] = nan_sentinel assignment later
452+
# this is important in shared-memory parallelism with dask
453+
# TODO: figure out how to avoid this
454+
idx = flat.copy()
450455
found_groups.append(np.array(expect))
451456
# TODO: fix by using masked integers
452457
idx[idx > expect[-1]] = -1

tests/test_core.py

+27
Original file line numberDiff line numberDiff line change
@@ -1282,3 +1282,30 @@ def test_1d_blockwise_sort_optimization():
12821282
array, time.dt.dayofyear.values[::-1], sort=False, method="blockwise", func="count"
12831283
)
12841284
assert all("getitem" not in k for k in actual.dask.layers)
1285+
1286+
1287+
@requires_dask
1288+
def test_negative_index_factorize_race_condition():
1289+
# shape = (10, 2000)
1290+
# chunks = ((shape[0]-1,1), 10)
1291+
shape = (101, 174000)
1292+
chunks = ((101,), 8760)
1293+
eps = dask.array.random.random_sample(shape, chunks=chunks)
1294+
N2 = dask.array.random.random_sample(shape, chunks=chunks)
1295+
S2 = dask.array.random.random_sample(shape, chunks=chunks)
1296+
1297+
bins = np.arange(-5, -2.05, 0.1)
1298+
func = ["mean", "count", "sum"]
1299+
1300+
out = [
1301+
groupby_reduce(
1302+
eps,
1303+
N2,
1304+
S2,
1305+
func=f,
1306+
expected_groups=(bins, bins),
1307+
isbin=(True, True),
1308+
)
1309+
for f in func
1310+
]
1311+
[dask.compute(out, scheduler="threads") for _ in range(5)]

0 commit comments

Comments
 (0)