Skip to content

Commit 630b0cb

Browse files
authored
Fix Dask handling for chunks with 1-wide dimension (#87)
1 parent c34cb6d commit 630b0cb

File tree

6 files changed, with 172 additions and 42 deletions

src/fast_array_utils/_validation.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ def validate_axis(ndim: int, axis: int | None) -> None:
99
if axis is None:
1010
return
1111
if not isinstance(axis, int | np.integer): # pragma: no cover
12-
msg = "axis must be integer or None."
12+
msg = f"axis must be integer or None, not {axis=!r}."
1313
raise TypeError(msg)
1414
if axis == 0 and ndim == 1:
1515
raise AxisError(axis, ndim, "use axis=None for 1D arrays")

src/fast_array_utils/stats/__init__.py

+43-5
Original file line number | Diff line number | Diff line change
@@ -227,19 +227,56 @@ def mean_var(
227227
# https://github.com/scverse/fast-array-utils/issues/52
228228
@overload
229229
def sum(
230-
x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None
230+
x: CpuArray | DiskArray,
231+
/,
232+
*,
233+
axis: None = None,
234+
dtype: DTypeLike | None = None,
235+
keep_cupy_as_array: bool = False,
231236
) -> np.number[Any]: ...
232237
@overload
233238
def sum(
234-
x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
239+
x: CpuArray | DiskArray,
240+
/,
241+
*,
242+
axis: Literal[0, 1],
243+
dtype: DTypeLike | None = None,
244+
keep_cupy_as_array: bool = False,
235245
) -> NDArray[Any]: ...
246+
247+
236248
@overload
237249
def sum(
238-
x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
250+
x: GpuArray,
251+
/,
252+
*,
253+
axis: None = None,
254+
dtype: DTypeLike | None = None,
255+
keep_cupy_as_array: Literal[False] = False,
256+
) -> np.number[Any]: ...
257+
@overload
258+
def sum(
259+
x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]
260+
) -> types.CupyArray: ...
261+
@overload
262+
def sum(
263+
x: GpuArray,
264+
/,
265+
*,
266+
axis: Literal[0, 1],
267+
dtype: DTypeLike | None = None,
268+
keep_cupy_as_array: bool = False,
239269
) -> types.CupyArray: ...
270+
271+
240272
@overload
241273
def sum(
242-
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
274+
x: types.DaskArray,
275+
/,
276+
*,
277+
axis: Literal[0, 1, None] = None,
278+
dtype: DTypeLike | None = None,
279+
keep_cupy_as_array: bool = False,
243280
) -> types.DaskArray: ...
244281

245282

@@ -249,6 +286,7 @@ def sum(
249286
*,
250287
axis: Literal[0, 1, None] = None,
251288
dtype: DTypeLike | None = None,
289+
keep_cupy_as_array: bool = False,
252290
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
253291
"""Sum over both or one axis.
254292
@@ -286,4 +324,4 @@ def sum(
286324
from ._sum import sum_
287325

288326
validate_axis(x.ndim, axis)
289-
return sum_(x, axis=axis, dtype=dtype)
327+
return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array)

src/fast_array_utils/stats/_sum.py

+100-35
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,25 @@
22
from __future__ import annotations
33

44
from functools import partial, singledispatch
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING, Literal, cast
66

77
import numpy as np
8+
from numpy.exceptions import AxisError
89

910
from .. import types
1011

1112

1213
if TYPE_CHECKING:
13-
from typing import Any, Literal
14+
from typing import Any, Literal, TypeAlias
1415

1516
from numpy.typing import DTypeLike, NDArray
1617

1718
from ..typing import CpuArray, DiskArray, GpuArray
1819

20+
ComplexAxis: TypeAlias = (
21+
tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None]
22+
)
23+
1924

2025
@singledispatch
2126
def sum_(
@@ -24,7 +29,9 @@ def sum_(
2429
*,
2530
axis: Literal[0, 1, None] = None,
2631
dtype: DTypeLike | None = None,
32+
keep_cupy_as_array: bool = False,
2733
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
34+
del keep_cupy_as_array
2835
if TYPE_CHECKING:
2936
# these are never passed to this fallback function, but `singledispatch` wants them
3037
assert not isinstance(
@@ -37,16 +44,31 @@ def sum_(
3744

3845
@sum_.register(types.CupyArray | types.CupyCSMatrix) # type: ignore[call-overload,misc]
3946
def _sum_cupy(
40-
x: GpuArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
47+
x: GpuArray,
48+
/,
49+
*,
50+
axis: Literal[0, 1, None] = None,
51+
dtype: DTypeLike | None = None,
52+
keep_cupy_as_array: bool = False,
4153
) -> types.CupyArray | np.number[Any]:
4254
arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype))
43-
return cast("np.number[Any]", arr.get()[()]) if axis is None else arr.squeeze()
55+
return (
56+
cast("np.number[Any]", arr.get()[()])
57+
if not keep_cupy_as_array and axis is None
58+
else arr.squeeze()
59+
)
4460

4561

4662
@sum_.register(types.CSBase) # type: ignore[call-overload,misc]
4763
def _sum_cs(
48-
x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
64+
x: types.CSBase,
65+
/,
66+
*,
67+
axis: Literal[0, 1, None] = None,
68+
dtype: DTypeLike | None = None,
69+
keep_cupy_as_array: bool = False,
4970
) -> NDArray[Any] | np.number[Any]:
71+
del keep_cupy_as_array
5072
import scipy.sparse as sp
5173

5274
if isinstance(x, types.CSMatrix):
@@ -59,49 +81,92 @@ def _sum_cs(
5981

6082
@sum_.register(types.DaskArray)
6183
def _sum_dask(
62-
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
84+
x: types.DaskArray,
85+
/,
86+
*,
87+
axis: Literal[0, 1, None] = None,
88+
dtype: DTypeLike | None = None,
89+
keep_cupy_as_array: bool = False,
6390
) -> types.DaskArray:
6491
import dask.array as da
6592

66-
from . import sum
67-
6893
if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001
6994
msg = "sum does not support numpy matrices"
7095
raise TypeError(msg)
7196

72-
def sum_drop_keepdims(
73-
a: CpuArray,
74-
/,
75-
*,
76-
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None,
77-
dtype: DTypeLike | None = None,
78-
keepdims: bool = False,
79-
) -> NDArray[Any] | types.CupyArray:
80-
del keepdims
81-
if a.ndim == 1:
82-
axis = None
83-
else:
84-
match axis:
85-
case (0, 1) | (1, 0):
86-
axis = None
87-
case (0 | 1 as n,):
88-
axis = n
89-
case tuple(): # pragma: no cover
90-
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
91-
raise ValueError(msg)
92-
rv = sum(a, axis=axis, dtype=dtype)
93-
shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type]
94-
return np.reshape(rv, shape)
95-
9697
if dtype is None:
9798
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
9899
dtype = np.zeros(1, dtype=x.dtype).sum().dtype
99100

100-
return da.reduction(
101+
rv = da.reduction(
101102
x,
102-
sum_drop_keepdims, # type: ignore[arg-type]
103-
partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType]
103+
sum_dask_inner, # type: ignore[arg-type]
104+
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
104105
axis=axis,
105106
dtype=dtype,
106107
meta=np.array([], dtype=dtype),
107108
)
109+
110+
if axis is not None or (
111+
isinstance(rv._meta, types.CupyArray) # noqa: SLF001
112+
and keep_cupy_as_array
113+
):
114+
return rv
115+
116+
def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]:
117+
if isinstance(a, types.CupyArray):
118+
a = a.get()
119+
return a.reshape(())[()] # type: ignore[return-value]
120+
121+
return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type]
122+
123+
124+
def sum_dask_inner(
    a: CpuArray | GpuArray,
    /,
    *,
    axis: ComplexAxis = None,
    dtype: DTypeLike | None = None,
    keepdims: bool = False,
) -> NDArray[Any] | types.CupyArray:
    """Sum one block of a Dask array and reshape the result.

    The ``axis`` argument is first translated by ``normalize_axis`` into one
    of ``0``, ``1`` or ``None``; the block is then summed with the package's
    own ``sum`` (always asking for CuPy results to stay arrays), and the
    result is reshaped to whatever ``get_shape`` reports for the given
    ``axis`` / ``keepdims`` combination.

    Parameters
    ----------
    a
        The block to sum.
    axis
        Axis specification; may be a tuple as well as ``0``, ``1`` or ``None``.
    dtype
        Forwarded unchanged to ``sum``.
    keepdims
        Forwarded to ``get_shape`` to decide the shape of the returned array.

    Returns
    -------
    The summed block, reshaped.
    """
    # imported lazily inside the function, exactly as the module does elsewhere
    from . import sum

    norm_axis = normalize_axis(axis, a.ndim)
    reduced = sum(a, axis=norm_axis, dtype=dtype, keep_cupy_as_array=True)  # type: ignore[misc,arg-type]
    target_shape = get_shape(reduced, axis=norm_axis, keepdims=keepdims)
    return cast("NDArray[Any] | types.CupyArray", reduced.reshape(target_shape))
138+
139+
140+
def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]:
141+
"""Adapt `axis` parameter passed by Dask to what we support."""
142+
match axis:
143+
case int() | None:
144+
pass
145+
case (0 | 1,):
146+
axis = axis[0]
147+
case (0, 1) | (1, 0):
148+
axis = None
149+
case _: # pragma: no cover
150+
raise AxisError(axis, ndim) # type: ignore[call-overload]
151+
if axis == 0 and ndim == 1:
152+
return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
153+
return axis
154+
155+
156+
def get_shape(
157+
a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool
158+
) -> tuple[int] | tuple[int, int]:
159+
"""Get the output shape of an axis-flattening operation."""
160+
match keepdims, a.ndim:
161+
case False, 0:
162+
return (1,)
163+
case True, 0:
164+
return (1, 1)
165+
case False, 1:
166+
return (a.size,)
167+
case True, 1:
168+
assert axis is not None
169+
return (1, a.size) if axis == 0 else (a.size, 1)
170+
# pragma: no cover
171+
msg = f"{keepdims=}, {type(a)}"
172+
raise AssertionError(msg)

tests/test_stats.py

+23-1
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
9797

9898
@pytest.fixture
9999
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
100-
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
100+
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in))
101101
np_arr.flags.writeable = False
102102
if ndim == 1:
103103
np_arr = np_arr.flatten()
@@ -165,6 +165,28 @@ def test_sum(
165165
np.testing.assert_array_equal(sum_, expected)
166166

167167

168+
@pytest.mark.parametrize(
    "data",
    [
        pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"),
        pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"),
        pytest.param([[1, 0], [0, 2]], id="2x2"),
    ],
)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.array_type(Flags.Dask)
def test_sum_dask_shapes(
    array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]
) -> None:
    """Check per-axis sums of Dask arrays whose chunks have a 1-wide dimension.

    Each parametrized matrix is converted to a Dask array via ``array_type``;
    the test first asserts that the resulting chunking really contains a
    length-1 dimension, then compares the computed sum to NumPy's result.
    """
    np_arr = np.array(data, dtype=np.float32)
    arr = array_type(np_arr)
    # guard: if the fixture's chunking changes, this test would no longer cover the intended case
    assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
    sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
    if isinstance(sum_, types.CupyArray):
        # bring GPU results to the host before comparing with NumPy
        sum_ = sum_.get()
    # (actual, desired) order, so that failure messages label the values correctly
    np.testing.assert_almost_equal(sum_, np_arr.sum(axis=axis))
188+
189+
168190
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
169191
def test_mean(
170192
array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]

typings/cupy/_core/core.pyi

+3
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,15 @@ from numpy.typing import NDArray
88
class ndarray:
99
dtype: np.dtype[Any]
1010
shape: tuple[int, ...]
11+
size: int
1112
ndim: int
1213

1314
# cupy-specific
1415
def get(self) -> NDArray[Any]: ...
1516

1617
# operators
1718
def __array__(self) -> NDArray[Any]: ...
19+
def __len__(self) -> int: ...
1820
def __getitem__( # never returns scalars
1921
self, index: int | slice | EllipsisType | tuple[int | slice | EllipsisType | None, ...]
2022
) -> Self: ...
@@ -28,6 +30,7 @@ class ndarray:
2830
def all(self, axis: None = None) -> np.bool: ...
2931
@overload
3032
def all(self, axis: int) -> ndarray: ...
33+
def reshape(self, shape: tuple[int, ...] | int) -> ndarray: ...
3134
def squeeze(self, axis: int | None = None) -> Self: ...
3235
def ravel(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...
3336
def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...

typings/dask/array/core.pyi

+2
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,8 @@ class Array:
4444
# dask methods and attrs
4545
_meta: _Array
4646
blocks: BlockView
47+
chunks: tuple[tuple[int, ...], ...]
48+
chunksize: tuple[int, ...]
4749

4850
def compute(self) -> _Array: ...
4951
def visualize(

0 commit comments

Comments
 (0)