Skip to content

Commit faf67eb

Browse files
committed
Add ignore_old_chunks for rechunk_for_cohorts
1 parent 59cd665 commit faf67eb

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

flox/core.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,9 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
217217
return chunks_cohorts.values()
218218

219219

220-
def rechunk_for_cohorts(array, axis, labels, force_new_chunk_at, chunksize=None):
220+
def rechunk_for_cohorts(
221+
array, axis, labels, force_new_chunk_at, chunksize=None, ignore_old_chunks=False, debug=False
222+
):
221223
"""
222224
Rechunks array so that each new chunk contains groups that always occur together.
223225
@@ -257,6 +259,9 @@ def rechunk_for_cohorts(array, axis, labels, force_new_chunk_at, chunksize=None)
257259
force_new_chunk_at = _atleast_1d(force_new_chunk_at)
258260
oldchunks = array.chunks[axis]
259261
oldbreaks = np.insert(np.cumsum(oldchunks), 0, 0)
262+
if debug:
263+
labels_at_breaks = labels[oldbreaks[:-1]]
264+
print(labels_at_breaks[:40])
260265

261266
isbreak = np.isin(labels, force_new_chunk_at)
262267
if not np.any(isbreak):
@@ -276,13 +281,19 @@ def rechunk_for_cohorts(array, axis, labels, force_new_chunk_at, chunksize=None)
276281
else:
277282
next_break_is_close = False
278283

279-
if idx in oldbreaks or (counter >= chunksize and not next_break_is_close):
284+
if (not ignore_old_chunks and idx in oldbreaks) or (
285+
counter >= chunksize and not next_break_is_close
286+
):
280287
divisions.append(idx)
281288
counter = 1
282289
continue
283290
counter += 1
284291

285292
divisions.append(len(labels))
293+
if debug:
294+
labels_at_breaks = labels[divisions[:-1]]
295+
print(labels_at_breaks[:40])
296+
286297
newchunks = tuple(np.diff(divisions))
287298
assert sum(newchunks) == len(labels)
288299

flox/xarray.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,8 @@ def rechunk_for_cohorts(
393393
labels: DataArray,
394394
force_new_chunk_at,
395395
chunksize: int | None = None,
396+
ignore_old_chunks: bool = False,
397+
debug: bool = False,
396398
):
397399
"""
398400
Rechunks array so that each new chunk contains groups that always occur together.
@@ -428,6 +430,8 @@ def rechunk_for_cohorts(
428430
labels,
429431
force_new_chunk_at=force_new_chunk_at,
430432
chunksize=chunksize,
433+
ignore_old_chunks=ignore_old_chunks,
434+
debug=debug,
431435
)
432436

433437

0 commit comments

Comments
 (0)