@@ -9,6 +9,7 @@
 from collections import namedtuple
 from collections.abc import Sequence
 from functools import partial, reduce
+from itertools import product
 from numbers import Integral
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@
 import numpy_groupies as npg
 import pandas as pd
 import toolz as tlz
+from scipy.sparse import csc_array

 from . import xrdtypes
 from .aggregate_flox import _prepare_for_flox
@@ -203,6 +205,16 @@ def _unique(a: np.ndarray) -> np.ndarray:
     return np.sort(pd.unique(a.reshape(-1)))


+def slices_from_chunks(chunks):
+    """slightly modified from dask.array.core.slices_from_chunks to be lazy"""
+    cumdims = [tlz.accumulate(operator.add, bds, 0) for bds in chunks]
+    slices = (
+        (slice(s, s + dim) for s, dim in zip(starts, shapes))
+        for starts, shapes in zip(cumdims, chunks)
+    )
+    return product(*slices)
+
+
 @memoize
 def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     """
@@ -215,9 +227,10 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     Parameters
     ----------
     labels : np.ndarray
-        mD Array of group labels
+        mD Array of integer group codes, factorized so that -1
+        represents NaNs.
     chunks : tuple
-        nD array that is being reduced
+        chunks of the array being reduced
     merge : bool, optional
         Attempt to merge cohorts when one cohort's chunks are a subset
         of another cohort's chunks.
@@ -227,33 +240,59 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     cohorts: dict_values
         Iterable of cohorts
     """
-    import dask
-
     # To do this, we must have values in memory so casting to numpy should be safe
     labels = np.asarray(labels)

-    # Build an array with the shape of labels, but where every element is the "chunk number"
-    # 1. First subset the array appropriately
-    axis = range(-labels.ndim, 0)
-    # Easier to create a dask array and use the .blocks property
-    array = dask.array.empty(tuple(sum(c) for c in chunks), chunks=chunks)
-    labels = np.broadcast_to(labels, array.shape[-labels.ndim :])
-
-    # Iterate over each block and create a new block of same shape with "chunk number"
-    shape = tuple(array.blocks.shape[ax] for ax in axis)
-    # Use a numpy object array to enable assignment in the loop
-    # TODO: is it possible to just use a nested list?
-    # That is what we need for `np.block`
-    blocks = np.empty(shape, dtype=object)
-    array_chunks = tuple(np.array(c) for c in array.chunks)
-    for idx, blockindex in enumerate(np.ndindex(array.numblocks)):
-        chunkshape = tuple(c[i] for c, i in zip(array_chunks, blockindex))
-        blocks[blockindex] = np.full(chunkshape, idx)
-    which_chunk = np.block(blocks.tolist()).reshape(-1)
-
-    raveled = labels.reshape(-1)
-    # these are chunks where a label is present
-    label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
+    shape = tuple(sum(c) for c in chunks)
+    nchunks = math.prod(len(c) for c in chunks)
+
+    # assumes that `labels` are factorized
+    nlabels = labels.max() + 1
+
+    labels = np.broadcast_to(labels, shape[-labels.ndim :])
+
+    rows = []
+    cols = []
+    # Add one to handle the -1 sentinel value
+    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
+    ilabels = np.arange(nlabels)
+    for idx, region in enumerate(slices_from_chunks(chunks)):
+        # This is a quite fast way to find unique integers, when we know how many there are
+        # inspired by a similar idea in numpy_groupies for first, last
+        # instead of explicitly finding uniques, repeatedly write True to the same location
+        subset = labels[region]
+        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
+        label_is_present[subset.reshape(-1)] = True
+        # skip the -1 sentinel by slicing
+        uniques = ilabels[label_is_present[:-1]]
+        rows.append([idx] * len(uniques))
+        cols.append(uniques)
+        label_is_present[:] = False
+    rows_array = np.concatenate(rows)
+    cols_array = np.concatenate(cols)
+    data = np.broadcast_to(np.array(1, dtype=np.uint8), rows_array.shape)
+    bitmask = csc_array((data, (rows_array, cols_array)), dtype=bool, shape=(nchunks, nlabels))
+    label_chunks = {
+        lab: bitmask.indices[slice(bitmask.indptr[lab], bitmask.indptr[lab + 1])]
+        for lab in range(nlabels)
+    }
+
+    ## numpy bitmask approach, faster than finding uniques, but lots of memory
+    # bitmask = np.zeros((nchunks, nlabels), dtype=bool)
+    # for idx, region in enumerate(slices_from_chunks(chunks)):
+    #     bitmask[idx, labels[region]] = True
+    # bitmask = bitmask[:, :-1]
+    # chunk = np.arange(nchunks)  # [:, np.newaxis] * bitmask
+    # label_chunks = {lab: chunk[bitmask[:, lab]] for lab in range(nlabels - 1)}
+
+    ## Pandas GroupBy approach, quite slow!
+    # which_chunk = np.empty(shape, dtype=np.int64)
+    # for idx, region in enumerate(slices_from_chunks(chunks)):
+    #     which_chunk[region] = idx
+    # which_chunk = which_chunk.reshape(-1)
+    # raveled = labels.reshape(-1)
+    # # these are chunks where a label is present
+    # label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

     # These invert the label_chunks mapping so we know which labels occur together.
     def invert(x) -> tuple[np.ndarray, ...]:
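The loop above collects (chunk, label) pairs and packs them into a chunk-by-label boolean matrix; storing it as a csc_array means each label's chunks are exactly the row indices of one column, read straight from indptr/indices. A toy sketch of that lookup, with made-up rows/cols values for illustration:

# Toy illustration: a chunk-by-label boolean matrix in CSC form; each label's
# chunks are the row indices stored for that label's column.
import numpy as np
from scipy.sparse import csc_array

nchunks, nlabels = 3, 4
# hypothetical (chunk, label) pairs collected while scanning the blocks
rows = np.array([0, 0, 1, 1, 2, 2])
cols = np.array([0, 1, 1, 2, 2, 3])
data = np.ones_like(rows, dtype=np.uint8)

bitmask = csc_array((data, (rows, cols)), dtype=bool, shape=(nchunks, nlabels))

# CSC stores column `lab`'s row indices in indices[indptr[lab]:indptr[lab + 1]]
label_chunks = {
    lab: bitmask.indices[bitmask.indptr[lab] : bitmask.indptr[lab + 1]]
    for lab in range(nlabels)
}
print(label_chunks)
# {0: array([0]), 1: array([0, 1]), 2: array([1, 2]), 3: array([2])}
# (the integer dtype of the index arrays may vary by platform)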
@@ -264,33 +303,31 @@ def invert(x) -> tuple[np.ndarray, ...]:

     # If our dataset has chunksize one along the axis,
     # then no merging is possible.
-    single_chunks = all((ac == 1).all() for ac in array_chunks)
+    single_chunks = all(all(a == 1 for a in ac) for ac in chunks)

-    if merge and not single_chunks:
+    if not single_chunks and merge:
         # First sort by number of chunks occupied by cohort
         sorted_chunks_cohorts = dict(
             sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
         )

-        items = tuple(sorted_chunks_cohorts.items())
+        items = tuple((k, set(k), v) for k, v in sorted_chunks_cohorts.items() if k)

         merged_cohorts = {}
-        merged_keys = []
+        merged_keys = set()

         # Now we iterate starting with the longest number of chunks,
         # and then merge in cohorts that are present in a subset of those chunks
         # I think this is suboptimal and must fail at some point.
         # But it might work for most cases. There must be a better way...
-        for idx, (k1, v1) in enumerate(items):
+        for idx, (k1, set_k1, v1) in enumerate(items):
             if k1 in merged_keys:
                 continue
             merged_cohorts[k1] = copy.deepcopy(v1)
-            for k2, v2 in items[idx + 1 :]:
-                if k2 in merged_keys:
-                    continue
-                if set(k2).issubset(set(k1)):
+            for k2, set_k2, v2 in items[idx + 1 :]:
+                if k2 not in merged_keys and set_k2.issubset(set_k1):
                     merged_cohorts[k1].extend(v2)
-                    merged_keys.append(k2)
+                    merged_keys.update((k2,))

         # make sure each cohort is sorted after merging
         sorted_merged_cohorts = {k: sorted(v) for k, v in merged_cohorts.items()}
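The merge step keeps the cohorts that span the most chunks and folds in any later cohort whose chunk set is a subset of one already kept. A simplified sketch with a hypothetical chunks_cohorts input, not the library's exact code:

# Simplified sketch of the merge loop with made-up cohorts: chunk tuples that
# are subsets of a larger cohort's chunks get folded into that cohort.
import copy

# hypothetical mapping: tuple of chunk indices -> labels found in exactly those chunks
chunks_cohorts = {
    (0, 1, 2): [1, 4],
    (0, 1): [2],
    (2,): [3],
    (3,): [5],
}

# sort by how many chunks a cohort touches, largest first, and precompute sets
items = tuple(
    (k, set(k), v)
    for k, v in sorted(chunks_cohorts.items(), key=lambda kv: len(kv[0]), reverse=True)
    if k
)

merged_cohorts = {}
merged_keys = set()
for idx, (k1, set_k1, v1) in enumerate(items):
    if k1 in merged_keys:
        continue
    merged_cohorts[k1] = copy.deepcopy(v1)
    for k2, set_k2, v2 in items[idx + 1 :]:
        if k2 not in merged_keys and set_k2.issubset(set_k1):
            merged_cohorts[k1].extend(v2)
            merged_keys.add(k2)

print({k: sorted(v) for k, v in merged_cohorts.items()})
# {(0, 1, 2): [1, 2, 3, 4], (3,): [5]}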
@@ -1373,7 +1410,6 @@ def dask_groupby_agg(

     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
-    token = dask.base.tokenize(array, by, agg, expected_groups, axis)

     if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)
@@ -1394,6 +1430,9 @@ def dask_groupby_agg(
         by = dask.array.from_array(by, chunks=chunks)
     _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

+    # tokenize here since by has already been hashed if it's numpy
+    token = dask.base.tokenize(array, by, agg, expected_groups, axis)
+
     # preprocess the array:
     # - for argreductions, this zips the index together with the array block
     # - not necessary for blockwise with argreductions
@@ -1510,7 +1549,7 @@ def dask_groupby_agg(
             index = pd.Index(cohort)
             subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
             reindexed = dask.array.map_blocks(
-                reindex_intermediates, subset, agg=agg, unique_groups=index, meta=subset._meta
+                reindex_intermediates, subset, agg, index, meta=subset._meta
             )
             # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(