@@ -370,28 +370,31 @@ def cull(
370370 indices_to_keep = self ._keys_to_indices (keys )
371371 _old_to_new = old_to_new (self .chunks_input , self .chunks )
372372
373- culled_deps : defaultdict [Key , set [Key ]] = defaultdict (set )
374- for nindex in indices_to_keep :
375- old_indices_per_axis = []
376- keepmap [nindex ] = True
377- for index , new_axis in zip (nindex , _old_to_new ):
378- old_indices_per_axis .append (
379- [old_chunk_index for old_chunk_index , _ in new_axis [index ]]
380- )
381- for old_nindex in product (* old_indices_per_axis ):
382- culled_deps [(self .name ,) + nindex ].add ((self .name_input ,) + old_nindex )
373+ for ndindex in indices_to_keep :
374+ keepmap [ndindex ] = True
383375
384- # Protect against mutations later on with frozenset
385- frozen_deps = {
386- output_task : frozenset (input_tasks )
387- for output_task , input_tasks in culled_deps .items ()
388- }
376+ culled_deps = {}
377+ # Identify the individual partial rechunks
378+ for ndpartial in _split_partials (_old_to_new ):
379+ # Cull partials for which we do not keep any output tasks
380+ if not np .any (keepmap [ndpartial .new ]):
381+ continue
382+
383+ # Within partials, we have all-to-all communication.
384+ # Thus, all output tasks share the same input tasks.
385+ deps = frozenset (
386+ (self .name_input ,) + ndindex
387+ for ndindex in _ndindices_of_slice (ndpartial .old )
388+ )
389+
390+ for ndindex in _ndindices_of_slice (ndpartial .new ):
391+ culled_deps [(self .name ,) + ndindex ] = deps
389392
390393 if np .array_equal (keepmap , self .keepmap ):
391- return self , frozen_deps
394+ return self , culled_deps
392395 else :
393396 culled_layer = self ._cull (keepmap )
394- return culled_layer , frozen_deps
397+ return culled_layer , culled_deps
395398
396399 def _construct_graph (self ) -> _T_LowLevelGraph :
397400 import numpy as np
@@ -695,14 +698,12 @@ def _slice_new_chunks_into_partials(
695698 return tuple (sliced_axes )
696699
697700
698- def _partial_ndindex (ndslice : NDSlice ) -> np .ndindex :
699- import numpy as np
700-
701- return np .ndindex (tuple (slice .stop - slice .start for slice in ndslice ))
701+ def _ndindices_of_slice (ndslice : NDSlice ) -> Iterator [NDIndex ]:
702+ return product (* (range (slc .start , slc .stop ) for slc in ndslice ))
702703
703704
704- def _global_index ( partial_index : NDIndex , partial_offset : NDIndex ) -> NDIndex :
705- return tuple (index + offset for index , offset in zip (partial_index , partial_offset ))
705+ def _partial_index ( global_index : NDIndex , partial_offset : NDIndex ) -> NDIndex :
706+ return tuple (index - offset for index , offset in zip (global_index , partial_offset ))
706707
707708
708709def partial_concatenate (
@@ -802,8 +803,8 @@ def partial_rechunk(
802803 )
803804
804805 transfer_keys = []
805- for partial_index in _partial_ndindex (ndpartial .old ):
806- global_index = _global_index ( partial_index , old_partial_offset )
806+ for global_index in _ndindices_of_slice (ndpartial .old ):
807+ partial_index = _partial_index ( global_index , old_partial_offset )
807808
808809 input_key = (input_name ,) + global_index
809810
@@ -822,8 +823,8 @@ def partial_rechunk(
822823 dsk [_barrier_key ] = (shuffle_barrier , partial_token , transfer_keys )
823824
824825 new_partial_offset = tuple (axis .start for axis in ndpartial .new )
825- for partial_index in _partial_ndindex (ndpartial .new ):
826- global_index = _global_index ( partial_index , new_partial_offset )
826+ for global_index in _ndindices_of_slice (ndpartial .new ):
827+ partial_index = _partial_index ( global_index , new_partial_offset )
827828 if keepmap [global_index ]:
828829 dsk [(unpack_group ,) + global_index ] = (
829830 rechunk_unpack ,
0 commit comments