@@ -217,7 +217,9 @@ def find_group_cohorts(labels, chunks, merge=True, method="cohorts"):
217217 return chunks_cohorts .values ()
218218
219219
220- def rechunk_for_cohorts (array , axis , labels , force_new_chunk_at , chunksize = None ):
220+ def rechunk_for_cohorts (
221+ array , axis , labels , force_new_chunk_at , chunksize = None , ignore_old_chunks = False , debug = False
222+ ):
221223 """
222224 Rechunks array so that each new chunk contains groups that always occur together.
223225
@@ -257,6 +259,9 @@ def rechunk_for_cohorts(array, axis, labels, force_new_chunk_at, chunksize=None)
257259 force_new_chunk_at = _atleast_1d (force_new_chunk_at )
258260 oldchunks = array .chunks [axis ]
259261 oldbreaks = np .insert (np .cumsum (oldchunks ), 0 , 0 )
262+ if debug :
263+ labels_at_breaks = labels [oldbreaks [:- 1 ]]
264+ print (labels_at_breaks [:40 ])
260265
261266 isbreak = np .isin (labels , force_new_chunk_at )
262267 if not np .any (isbreak ):
@@ -276,13 +281,19 @@ def rechunk_for_cohorts(array, axis, labels, force_new_chunk_at, chunksize=None)
276281 else :
277282 next_break_is_close = False
278283
279- if idx in oldbreaks or (counter >= chunksize and not next_break_is_close ):
284+ if (not ignore_old_chunks and idx in oldbreaks ) or (
285+ counter >= chunksize and not next_break_is_close
286+ ):
280287 divisions .append (idx )
281288 counter = 1
282289 continue
283290 counter += 1
284291
285292 divisions .append (len (labels ))
293+ if debug :
294+ labels_at_breaks = labels [divisions [:- 1 ]]
295+ print (labels_at_breaks [:40 ])
296+
286297 newchunks = tuple (np .diff (divisions ))
287298 assert sum (newchunks ) == len (labels )
288299
0 commit comments