Skip to content

Commit e19c630

Browse files
Reduce memory footprint of culling P2P rechunking (#8845)
1 parent 4aeed40 commit e19c630

1 file changed

Lines changed: 28 additions & 27 deletions

File tree

distributed/shuffle/_rechunk.py

Lines changed: 28 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -370,28 +370,31 @@ def cull(
370370
indices_to_keep = self._keys_to_indices(keys)
371371
_old_to_new = old_to_new(self.chunks_input, self.chunks)
372372

373-
culled_deps: defaultdict[Key, set[Key]] = defaultdict(set)
374-
for nindex in indices_to_keep:
375-
old_indices_per_axis = []
376-
keepmap[nindex] = True
377-
for index, new_axis in zip(nindex, _old_to_new):
378-
old_indices_per_axis.append(
379-
[old_chunk_index for old_chunk_index, _ in new_axis[index]]
380-
)
381-
for old_nindex in product(*old_indices_per_axis):
382-
culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex)
373+
for ndindex in indices_to_keep:
374+
keepmap[ndindex] = True
383375

384-
# Protect against mutations later on with frozenset
385-
frozen_deps = {
386-
output_task: frozenset(input_tasks)
387-
for output_task, input_tasks in culled_deps.items()
388-
}
376+
culled_deps = {}
377+
# Identify the individual partial rechunks
378+
for ndpartial in _split_partials(_old_to_new):
379+
# Cull partials for which we do not keep any output tasks
380+
if not np.any(keepmap[ndpartial.new]):
381+
continue
382+
383+
# Within partials, we have all-to-all communication.
384+
# Thus, all output tasks share the same input tasks.
385+
deps = frozenset(
386+
(self.name_input,) + ndindex
387+
for ndindex in _ndindices_of_slice(ndpartial.old)
388+
)
389+
390+
for ndindex in _ndindices_of_slice(ndpartial.new):
391+
culled_deps[(self.name,) + ndindex] = deps
389392

390393
if np.array_equal(keepmap, self.keepmap):
391-
return self, frozen_deps
394+
return self, culled_deps
392395
else:
393396
culled_layer = self._cull(keepmap)
394-
return culled_layer, frozen_deps
397+
return culled_layer, culled_deps
395398

396399
def _construct_graph(self) -> _T_LowLevelGraph:
397400
import numpy as np
@@ -695,14 +698,12 @@ def _slice_new_chunks_into_partials(
695698
return tuple(sliced_axes)
696699

697700

698-
def _partial_ndindex(ndslice: NDSlice) -> np.ndindex:
699-
import numpy as np
700-
701-
return np.ndindex(tuple(slice.stop - slice.start for slice in ndslice))
701+
def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]:
702+
return product(*(range(slc.start, slc.stop) for slc in ndslice))
702703

703704

704-
def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
705-
return tuple(index + offset for index, offset in zip(partial_index, partial_offset))
705+
def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex:
706+
return tuple(index - offset for index, offset in zip(global_index, partial_offset))
706707

707708

708709
def partial_concatenate(
@@ -802,8 +803,8 @@ def partial_rechunk(
802803
)
803804

804805
transfer_keys = []
805-
for partial_index in _partial_ndindex(ndpartial.old):
806-
global_index = _global_index(partial_index, old_partial_offset)
806+
for global_index in _ndindices_of_slice(ndpartial.old):
807+
partial_index = _partial_index(global_index, old_partial_offset)
807808

808809
input_key = (input_name,) + global_index
809810

@@ -822,8 +823,8 @@ def partial_rechunk(
822823
dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys)
823824

824825
new_partial_offset = tuple(axis.start for axis in ndpartial.new)
825-
for partial_index in _partial_ndindex(ndpartial.new):
826-
global_index = _global_index(partial_index, new_partial_offset)
826+
for global_index in _ndindices_of_slice(ndpartial.new):
827+
partial_index = _partial_index(global_index, new_partial_offset)
827828
if keepmap[global_index]:
828829
dsk[(unpack_group,) + global_index] = (
829830
rechunk_unpack,

0 commit comments

Comments (0)