    quantile_new_dims_func,
)
from .cache import memoize
+ from .lib import ArrayLayer
from .xrutils import (
    _contains_cftime_datetimes,
    _to_pytimedelta,
        from typing import Unpack
    except (ModuleNotFoundError, ImportError):
        Unpack: Any  # type: ignore[no-redef]
-
-     import cubed.Array as CubedArray
-     import dask.array.Array as DaskArray
-     from dask.typing import Graph
+     from .types import CubedArray, DaskArray, Graph

    T_DuckArray: TypeAlias = np.ndarray | DaskArray | CubedArray  # Any?
    T_By: TypeAlias = T_DuckArray
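The TYPE_CHECKING-only aliases move into a shared .types module so core.py can import CubedArray, DaskArray, and Graph from one place. A hypothetical sketch of what such a module might contain, based only on the names imported above and the lines removed here; the real flox types module may look different:

# Hypothetical flox/types.py sketch; the names come from the import above,
# the fallback-to-Any pattern is an assumption.
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import cubed.Array as CubedArray
    import dask.array.Array as DaskArray
    from dask.typing import Graph
else:
    CubedArray = Any
    DaskArray = Any
    Graph = Any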
@@ -1191,7 +1189,7 @@ def _aggregate(
    agg: Aggregation,
    expected_groups: pd.Index | None,
    axis: T_Axes,
-     keepdims,
+     keepdims: bool,
    fill_value: Any,
    reindex: bool,
) -> FinalResultsDict:
@@ -1511,7 +1509,7 @@ def subset_to_blocks(
    blkshape: tuple[int, ...] | None = None,
    reindexer=identity,
    chunks_as_array: tuple[np.ndarray, ...] | None = None,
- ) -> DaskArray:
+ ) -> ArrayLayer:
    """
    Advanced indexing of .blocks such that we always get a regular array back.

@@ -1525,10 +1523,8 @@ def subset_to_blocks(
    -------
    dask.array
    """
-     import dask.array
    from dask.array.slicing import normalize_index
    from dask.base import tokenize
-     from dask.highlevelgraph import HighLevelGraph

    if blkshape is None:
        blkshape = array.blocks.shape
@@ -1538,9 +1534,6 @@ def subset_to_blocks(

    index = _normalize_indexes(array, flatblocks, blkshape)

-     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-         return dask.array.map_blocks(reindexer, array, meta=array._meta)
-
    # These rest is copied from dask.array.core.py with slight modifications
    index = normalize_index(index, array.numblocks)
    index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
@@ -1553,10 +1546,7 @@ def subset_to_blocks(

    keys = itertools.product(*(range(len(c)) for c in chunks))
    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
-
-     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
-
-     return dask.array.Array(graph, name, chunks, meta=array)
+     return ArrayLayer(layer=layer, chunks=chunks, name=name)


def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
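With this change, subset_to_blocks no longer materializes a dask array per cohort; it returns an ArrayLayer from flox's .lib module carrying the hand-built graph layer plus the metadata needed to assemble one array later. A minimal sketch of such a container, inferred only from how this diff constructs it (ArrayLayer(layer=..., chunks=..., name=...)) and consumes it (subset.layer); the actual flox.lib definition may differ:

# Hypothetical sketch inferred from this diff; not the actual flox.lib code.
from dataclasses import dataclass
from typing import Any


@dataclass
class ArrayLayer:
    layer: dict[tuple, Any]              # hand-built task graph: {(name, *block_idx): task}
    chunks: tuple[tuple[int, ...], ...]  # chunk sizes of the selected block subset
    name: str                            # layer/output name used in the task keys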
@@ -1613,6 +1603,9 @@ def dask_groupby_agg(
) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
    import dask.array
    from dask.array.core import slices_from_chunks
+     from dask.highlevelgraph import HighLevelGraph
+
+     from .dask_array_ops import _tree_reduce

    # I think _tree_reduce expects this
    assert isinstance(axis, Sequence)
@@ -1742,35 +1735,44 @@ def dask_groupby_agg(
        assert chunks_cohorts
        block_shape = array.blocks.shape[-len(axis) :]

-         reduced_ = []
+         out_name = f"{name}-reduce-{method}-{token}"
        groups_ = []
        chunks_as_array = tuple(np.array(c) for c in array.chunks)
-         for blks, cohort in chunks_cohorts.items():
+         dsk: Graph = {}
+         for icohort, (blks, cohort) in enumerate(chunks_cohorts.items()):
            cohort_index = pd.Index(cohort)
            reindexer = (
                partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
                if do_simple_combine
                else identity
            )
-             reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+             subset = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
+             dsk |= subset.layer  # type: ignore[operator]
            # now that we have reindexed, we can set reindex=True explicitlly
-             reduced_.append(
-                 tree_reduce(
-                     reindexed,
-                     combine=partial(combine, agg=agg, reindex=do_simple_combine),
-                     aggregate=partial(
-                         aggregate,
-                         expected_groups=cohort_index,
-                         reindex=do_simple_combine,
-                     ),
-                 )
+             _tree_reduce(
+                 subset,
+                 out_dsk=dsk,
+                 name=out_name,
+                 block_index=icohort,
+                 axis=axis,
+                 combine=partial(combine, agg=agg, reindex=do_simple_combine, keepdims=True),
+                 aggregate=partial(
+                     aggregate, expected_groups=cohort_index, reindex=do_simple_combine, keepdims=True
+                 ),
            )
            # This is done because pandas promotes to 64-bit types when an Index is created
            # So we use the index to generate the return value for consistency with "map-reduce"
            # This is important on windows
            groups_.append(cohort_index.values)

-         reduced = dask.array.concatenate(reduced_, axis=-1)
+         graph = HighLevelGraph.from_collections(out_name, dsk, dependencies=[intermediate])
+
+         out_chunks = list(array.chunks)
+         out_chunks[axis[-1]] = tuple(len(c) for c in chunks_cohorts.values())
+         for ax in axis[:-1]:
+             out_chunks[ax] = (1,)
+         reduced = dask.array.Array(graph, out_name, out_chunks, meta=array._meta)
+
        groups = (np.concatenate(groups_),)
        group_chunks = (tuple(len(cohort) for cohort in groups_),)
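The rewritten cohorts path hand-builds a single graph layer keyed by (out_name, block_index) and wraps it once with HighLevelGraph.from_collections plus the low-level dask.array.Array constructor, rather than concatenating one small dask array per cohort. The standalone toy below illustrates that same graph-building pattern on a trivial doubling kernel; the array, names, and doubling task are illustrative only and not part of flox:

import numpy as np
import dask.array as da
from dask.highlevelgraph import HighLevelGraph

# Toy illustration of the hand-built-layer pattern, not flox code.
x = da.arange(6, chunks=3)   # two input blocks: (x.name, 0) and (x.name, 1)
name = "double-" + x.name    # unique name for the output layer
# One task per output block, keyed by (name, block_index), mirroring how each
# cohort writes its reduced result into (out_name, icohort).
layer = {(name, i): (np.multiply, (x.name, i), 2) for i in range(x.numblocks[0])}
graph = HighLevelGraph.from_collections(name, layer, dependencies=[x])
result = da.Array(graph, name, chunks=x.chunks, meta=x._meta)
assert (result.compute() == np.arange(6) * 2).all()

Building the whole layer as a plain dict and wrapping it once keeps the scheduler-visible graph flat, which is the point of this refactor: one output array whose chunks along the reduced axis are sized by the number of groups in each cohort.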