2
2
from __future__ import annotations
3
3
4
4
from functools import partial , singledispatch
5
- from typing import TYPE_CHECKING , cast
5
+ from typing import TYPE_CHECKING , Literal , cast
6
6
7
7
import numpy as np
8
+ from numpy .exceptions import AxisError
8
9
9
10
from .. import types
10
11
11
12
12
13
if TYPE_CHECKING :
13
- from typing import Any , Literal
14
+ from typing import Any , Literal , TypeAlias
14
15
15
16
from numpy .typing import DTypeLike , NDArray
16
17
17
18
from ..typing import CpuArray , DiskArray , GpuArray
18
19
20
+ ComplexAxis : TypeAlias = (
21
+ tuple [Literal [0 ], Literal [1 ]] | tuple [Literal [0 , 1 ]] | Literal [0 , 1 , None ]
22
+ )
23
+
19
24
20
25
@singledispatch
21
26
def sum_ (
@@ -24,7 +29,9 @@ def sum_(
24
29
* ,
25
30
axis : Literal [0 , 1 , None ] = None ,
26
31
dtype : DTypeLike | None = None ,
32
+ keep_cupy_as_array : bool = False ,
27
33
) -> NDArray [Any ] | np .number [Any ] | types .CupyArray | types .DaskArray :
34
+ del keep_cupy_as_array
28
35
if TYPE_CHECKING :
29
36
# these are never passed to this fallback function, but `singledispatch` wants them
30
37
assert not isinstance (
@@ -37,16 +44,31 @@ def sum_(
37
44
38
45
@sum_ .register (types .CupyArray | types .CupyCSMatrix ) # type: ignore[call-overload,misc]
39
46
def _sum_cupy (
40
- x : GpuArray , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
47
+ x : GpuArray ,
48
+ / ,
49
+ * ,
50
+ axis : Literal [0 , 1 , None ] = None ,
51
+ dtype : DTypeLike | None = None ,
52
+ keep_cupy_as_array : bool = False ,
41
53
) -> types .CupyArray | np .number [Any ]:
42
54
arr = cast ("types.CupyArray" , np .sum (x , axis = axis , dtype = dtype ))
43
- return cast ("np.number[Any]" , arr .get ()[()]) if axis is None else arr .squeeze ()
55
+ return (
56
+ cast ("np.number[Any]" , arr .get ()[()])
57
+ if not keep_cupy_as_array and axis is None
58
+ else arr .squeeze ()
59
+ )
44
60
45
61
46
62
@sum_ .register (types .CSBase ) # type: ignore[call-overload,misc]
47
63
def _sum_cs (
48
- x : types .CSBase , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
64
+ x : types .CSBase ,
65
+ / ,
66
+ * ,
67
+ axis : Literal [0 , 1 , None ] = None ,
68
+ dtype : DTypeLike | None = None ,
69
+ keep_cupy_as_array : bool = False ,
49
70
) -> NDArray [Any ] | np .number [Any ]:
71
+ del keep_cupy_as_array
50
72
import scipy .sparse as sp
51
73
52
74
if isinstance (x , types .CSMatrix ):
@@ -59,49 +81,92 @@ def _sum_cs(
59
81
60
82
@sum_ .register (types .DaskArray )
61
83
def _sum_dask (
62
- x : types .DaskArray , / , * , axis : Literal [0 , 1 , None ] = None , dtype : DTypeLike | None = None
84
+ x : types .DaskArray ,
85
+ / ,
86
+ * ,
87
+ axis : Literal [0 , 1 , None ] = None ,
88
+ dtype : DTypeLike | None = None ,
89
+ keep_cupy_as_array : bool = False ,
63
90
) -> types .DaskArray :
64
91
import dask .array as da
65
92
66
- from . import sum
67
-
68
93
if isinstance (x ._meta , np .matrix ): # pragma: no cover # noqa: SLF001
69
94
msg = "sum does not support numpy matrices"
70
95
raise TypeError (msg )
71
96
72
- def sum_drop_keepdims (
73
- a : CpuArray ,
74
- / ,
75
- * ,
76
- axis : tuple [Literal [0 ], Literal [1 ]] | Literal [0 , 1 , None ] = None ,
77
- dtype : DTypeLike | None = None ,
78
- keepdims : bool = False ,
79
- ) -> NDArray [Any ] | types .CupyArray :
80
- del keepdims
81
- if a .ndim == 1 :
82
- axis = None
83
- else :
84
- match axis :
85
- case (0 , 1 ) | (1 , 0 ):
86
- axis = None
87
- case (0 | 1 as n ,):
88
- axis = n
89
- case tuple (): # pragma: no cover
90
- msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got { axis } instead"
91
- raise ValueError (msg )
92
- rv = sum (a , axis = axis , dtype = dtype )
93
- shape = (1 ,) if a .ndim == 1 else (1 , 1 if rv .shape == () else len (rv )) # type: ignore[arg-type]
94
- return np .reshape (rv , shape )
95
-
96
97
if dtype is None :
97
98
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
98
99
dtype = np .zeros (1 , dtype = x .dtype ).sum ().dtype
99
100
100
- return da .reduction (
101
+ rv = da .reduction (
101
102
x ,
102
- sum_drop_keepdims , # type: ignore[arg-type]
103
- partial (np . sum , dtype = dtype ), # pyright: ignore[reportArgumentType]
103
+ sum_dask_inner , # type: ignore[arg-type]
104
+ partial (sum_dask_inner , dtype = dtype ), # pyright: ignore[reportArgumentType]
104
105
axis = axis ,
105
106
dtype = dtype ,
106
107
meta = np .array ([], dtype = dtype ),
107
108
)
109
+
110
+ if axis is not None or (
111
+ isinstance (rv ._meta , types .CupyArray ) # noqa: SLF001
112
+ and keep_cupy_as_array
113
+ ):
114
+ return rv
115
+
116
+ def to_scalar (a : types .CupyArray | NDArray [Any ]) -> np .number [Any ]:
117
+ if isinstance (a , types .CupyArray ):
118
+ a = a .get ()
119
+ return a .reshape (())[()] # type: ignore[return-value]
120
+
121
+ return rv .map_blocks (to_scalar , meta = x .dtype .type (0 )) # type: ignore[arg-type]
122
+
123
+
124
+ def sum_dask_inner (
125
+ a : CpuArray | GpuArray ,
126
+ / ,
127
+ * ,
128
+ axis : ComplexAxis = None ,
129
+ dtype : DTypeLike | None = None ,
130
+ keepdims : bool = False ,
131
+ ) -> NDArray [Any ] | types .CupyArray :
132
+ from . import sum
133
+
134
+ axis = normalize_axis (axis , a .ndim )
135
+ rv = sum (a , axis = axis , dtype = dtype , keep_cupy_as_array = True ) # type: ignore[misc,arg-type]
136
+ shape = get_shape (rv , axis = axis , keepdims = keepdims )
137
+ return cast ("NDArray[Any] | types.CupyArray" , rv .reshape (shape ))
138
+
139
+
140
+ def normalize_axis (axis : ComplexAxis , ndim : int ) -> Literal [0 , 1 , None ]:
141
+ """Adapt `axis` parameter passed by Dask to what we support."""
142
+ match axis :
143
+ case int () | None :
144
+ pass
145
+ case (0 | 1 ,):
146
+ axis = axis [0 ]
147
+ case (0 , 1 ) | (1 , 0 ):
148
+ axis = None
149
+ case _: # pragma: no cover
150
+ raise AxisError (axis , ndim ) # type: ignore[call-overload]
151
+ if axis == 0 and ndim == 1 :
152
+ return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
153
+ return axis
154
+
155
+
156
+ def get_shape (
157
+ a : NDArray [Any ] | np .number [Any ] | types .CupyArray , * , axis : Literal [0 , 1 , None ], keepdims : bool
158
+ ) -> tuple [int ] | tuple [int , int ]:
159
+ """Get the output shape of an axis-flattening operation."""
160
+ match keepdims , a .ndim :
161
+ case False , 0 :
162
+ return (1 ,)
163
+ case True , 0 :
164
+ return (1 , 1 )
165
+ case False , 1 :
166
+ return (a .size ,)
167
+ case True , 1 :
168
+ assert axis is not None
169
+ return (1 , a .size ) if axis == 0 else (a .size , 1 )
170
+ # pragma: no cover
171
+ msg = f"{ keepdims = } , { type (a )} "
172
+ raise AssertionError (msg )
0 commit comments