@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import copy
 import itertools
 import math
 import operator
@@ -304,14 +303,15 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
     single_chunks = all(all(a == 1 for a in ac) for ac in chunks)
-
-    if not single_chunks and merge:
+    one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
+    if not one_group_per_chunk and not single_chunks and merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )
 
-        items = tuple((k, set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
+        # precompute needed metrics for the quadratic loop below.
+        items = tuple((k, len(k), set(k), v) for k, v in sorted_chunks_cohorts.items() if k)
 
         merged_cohorts = {}
         merged_keys: set[tuple] = set()
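
The new one_group_per_chunk early exit deserves a note. A minimal sketch of the check, assuming bitmask is a boolean (chunk, label) matrix as the axis=1 reduction implies; the example data below is illustrative, not flox's actual API:

import numpy as np

# Hypothetical bitmask: rows are chunks, columns are group labels;
# True where a label occurs in that chunk.
bitmask = np.array(
    [
        [True, False, False],
        [False, True, False],
        [False, False, True],
    ]
)

# Every chunk contains exactly one group, so each chunk already forms
# its own cohort and the quadratic merge loop can be skipped entirely.
one_group_per_chunk = (bitmask.sum(axis=1) == 1).all()
print(one_group_per_chunk)  # True
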
@@ -320,21 +320,28 @@ def invert(x) -> tuple[np.ndarray, ...]:
         # and then merge in cohorts that are present in a subset of those chunks
         # I think this is suboptimal and must fail at some point.
         # But it might work for most cases. There must be a better way...
-        for idx, (k1, set_k1, v1) in enumerate(items):
+        for idx, (k1, len_k1, set_k1, v1) in enumerate(items):
             if k1 in merged_keys:
                 continue
-            merged_cohorts[k1] = copy.deepcopy(v1)
-            for k2, set_k2, v2 in items[idx + 1 :]:
-                if k2 not in merged_keys and set_k2.issubset(set_k1):
-                    merged_cohorts[k1].extend(v2)
-                    merged_keys.update((k2,))
-
-        # make sure each cohort is sorted after merging
-        sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
+            new_key = set_k1
+            new_value = v1
+            # iterate in reverse since we expect small cohorts
+            # to be most likely merged in to larger ones
+            for k2, len_k2, set_k2, v2 in reversed(items[idx + 1 :]):
+                if k2 not in merged_keys:
+                    if (len(set_k2 & new_key) / len_k2) > 0.75:
+                        new_key |= set_k2
+                        new_value += v2
+                        merged_keys.update((k2,))
+            sorted_ = sorted(new_value)
+            merged_cohorts[tuple(sorted(new_key))] = sorted_
+            if idx == 0 and (len(sorted_) == nlabels) and (np.array(sorted_) == ilabels).all():
+                break
+
         # sort by first label in cohort
         # This will help when sort=True (default)
         # and we have to resort the dask array
-        return dict(sorted(sorted_merged_cohorts.items(), key=lambda kv: kv[1][0]))
+        return dict(sorted(merged_cohorts.items(), key=lambda kv: kv[1][0]))
 
     else:
         return chunks_cohorts
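
To make the behavior change concrete, here is a self-contained sketch of the new merge heuristic with hypothetical inputs, not flox's actual API: instead of only absorbing cohorts whose chunks are a strict subset of a larger cohort's, a cohort is now merged in whenever more than 75% of its chunks overlap the accumulated chunk set.

# chunks_cohorts maps a tuple of chunk indices -> list of group labels.
chunks_cohorts = {
    (0, 1, 2, 3): [1, 2],  # large cohort spanning chunks 0-3
    (0, 1, 2): [3],        # 100% contained in the set above -> merged
    (7, 8): [4],           # disjoint -> kept as its own cohort
}
items = sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
merged_cohorts: dict[tuple, list] = {}
merged_keys: set[tuple] = set()
for idx, (k1, v1) in enumerate(items):
    if k1 in merged_keys:
        continue
    new_key, new_value = set(k1), list(v1)
    # iterate in reverse: the smallest cohorts are the likeliest merge candidates
    for k2, v2 in reversed(items[idx + 1 :]):
        if k2 not in merged_keys and len(set(k2) & new_key) / len(k2) > 0.75:
            new_key |= set(k2)
            new_value += v2
            merged_keys.add(k2)
    merged_cohorts[tuple(sorted(new_key))] = sorted(new_value)
print(merged_cohorts)  # {(0, 1, 2, 3): [1, 2, 3], (7, 8): [4]}

The sketch omits two details from the diff: the final sort of cohorts by their first label, and the added idx == 0 break, which short-circuits when the very first merged cohort already contains every label (nlabels and ilabels are presumably the label count and label values computed earlier in the function), since in that case all groups form a single cohort and no further merging can help.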