Skip to content

Fix Dask handling for chunks with 1-wide dimension #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/fast_array_utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def validate_axis(ndim: int, axis: int | None) -> None:
if axis is None:
return
if not isinstance(axis, int | np.integer): # pragma: no cover
msg = "axis must be integer or None."
msg = f"axis must be integer or None, not {axis=!r}."
raise TypeError(msg)
if axis == 0 and ndim == 1:
raise AxisError(axis, ndim, "use axis=None for 1D arrays")
Expand Down
48 changes: 43 additions & 5 deletions src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,56 @@ def mean_var(
# https://github.com/scverse/fast-array-utils/issues/52
@overload
def sum(
x: CpuArray | GpuArray | DiskArray, /, *, axis: None = None, dtype: DTypeLike | None = None
x: CpuArray | DiskArray,
/,
*,
axis: None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> np.number[Any]: ...
@overload
def sum(
x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
x: CpuArray | DiskArray,
/,
*,
axis: Literal[0, 1],
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any]: ...


@overload
def sum(
x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
x: GpuArray,
/,
*,
axis: None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: Literal[False] = False,
) -> np.number[Any]: ...
@overload
def sum(
x: GpuArray, /, *, axis: None, dtype: DTypeLike | None = None, keep_cupy_as_array: Literal[True]
) -> types.CupyArray: ...
@overload
def sum(
x: GpuArray,
/,
*,
axis: Literal[0, 1],
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.CupyArray: ...


@overload
def sum(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
x: types.DaskArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.DaskArray: ...


Expand All @@ -249,6 +286,7 @@ def sum(
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
"""Sum over both or one axis.

Expand Down Expand Up @@ -286,4 +324,4 @@ def sum(
from ._sum import sum_

validate_axis(x.ndim, axis)
return sum_(x, axis=axis, dtype=dtype)
return sum_(x, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array)
135 changes: 100 additions & 35 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Literal, cast

import numpy as np
from numpy.exceptions import AxisError

from .. import types


if TYPE_CHECKING:
from typing import Any, Literal
from typing import Any, Literal, TypeAlias

from numpy.typing import DTypeLike, NDArray

from ..typing import CpuArray, DiskArray, GpuArray

ComplexAxis: TypeAlias = (
tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1, None]
)


@singledispatch
def sum_(
Expand All @@ -24,7 +29,9 @@ def sum_(
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
del keep_cupy_as_array
if TYPE_CHECKING:
# these are never passed to this fallback function, but `singledispatch` wants them
assert not isinstance(
Expand All @@ -37,16 +44,31 @@ def sum_(

@sum_.register(types.CupyArray | types.CupyCSMatrix) # type: ignore[call-overload,misc]
def _sum_cupy(
x: GpuArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
x: GpuArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
arr = cast("types.CupyArray", np.sum(x, axis=axis, dtype=dtype))
return cast("np.number[Any]", arr.get()[()]) if axis is None else arr.squeeze()
return (
cast("np.number[Any]", arr.get()[()])
if not keep_cupy_as_array and axis is None
else arr.squeeze()
)


@sum_.register(types.CSBase) # type: ignore[call-overload,misc]
def _sum_cs(
x: types.CSBase, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
x: types.CSBase,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
del keep_cupy_as_array
import scipy.sparse as sp

if isinstance(x, types.CSMatrix):
Expand All @@ -59,49 +81,92 @@ def _sum_cs(

@sum_.register(types.DaskArray)
def _sum_dask(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
x: types.DaskArray,
/,
*,
axis: Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.DaskArray:
import dask.array as da

from . import sum

if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001
msg = "sum does not support numpy matrices"
raise TypeError(msg)

def sum_drop_keepdims(
a: CpuArray,
/,
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
) -> NDArray[Any] | types.CupyArray:
del keepdims
if a.ndim == 1:
axis = None
else:
match axis:
case (0, 1) | (1, 0):
axis = None
case (0 | 1 as n,):
axis = n
case tuple(): # pragma: no cover
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)
rv = sum(a, axis=axis, dtype=dtype)
shape = (1,) if a.ndim == 1 else (1, 1 if rv.shape == () else len(rv)) # type: ignore[arg-type]
return np.reshape(rv, shape)

if dtype is None:
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
dtype = np.zeros(1, dtype=x.dtype).sum().dtype

return da.reduction(
rv = da.reduction(
x,
sum_drop_keepdims, # type: ignore[arg-type]
partial(np.sum, dtype=dtype), # pyright: ignore[reportArgumentType]
sum_dask_inner, # type: ignore[arg-type]
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
)

if axis is not None or (
isinstance(rv._meta, types.CupyArray) # noqa: SLF001
and keep_cupy_as_array
):
return rv

def to_scalar(a: types.CupyArray | NDArray[Any]) -> np.number[Any]:
if isinstance(a, types.CupyArray):
a = a.get()
return a.reshape(())[()] # type: ignore[return-value]

return rv.map_blocks(to_scalar, meta=x.dtype.type(0)) # type: ignore[arg-type]


def sum_dask_inner(
a: CpuArray | GpuArray,
/,
*,
axis: ComplexAxis = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
) -> NDArray[Any] | types.CupyArray:
from . import sum

axis = normalize_axis(axis, a.ndim)
rv = sum(a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,arg-type]
shape = get_shape(rv, axis=axis, keepdims=keepdims)
return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))


def normalize_axis(axis: ComplexAxis, ndim: int) -> Literal[0, 1, None]:
"""Adapt `axis` parameter passed by Dask to what we support."""
match axis:
case int() | None:
pass
case (0 | 1,):
axis = axis[0]
case (0, 1) | (1, 0):
axis = None
case _: # pragma: no cover
raise AxisError(axis, ndim) # type: ignore[call-overload]
if axis == 0 and ndim == 1:
return None # dask’s aggregate doesn’t know we don’t accept `axis=0` for 1D arrays
return axis


def get_shape(
a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Literal[0, 1, None], keepdims: bool
) -> tuple[int] | tuple[int, int]:
"""Get the output shape of an axis-flattening operation."""
match keepdims, a.ndim:
case False, 0:
return (1,)
case True, 0:
return (1, 1)
case False, 1:
return (a.size,)
case True, 1:
assert axis is not None
return (1, a.size) if axis == 0 else (a.size, 1)
# pragma: no cover
msg = f"{keepdims=}, {type(a)}"
raise AssertionError(msg)
24 changes: 23 additions & 1 deletion tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:

@pytest.fixture
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype_in))
np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in))
np_arr.flags.writeable = False
if ndim == 1:
np_arr = np_arr.flatten()
Expand Down Expand Up @@ -165,6 +165,28 @@ def test_sum(
np.testing.assert_array_equal(sum_, expected)


@pytest.mark.parametrize(
"data",
[
pytest.param([[1, 0], [3, 0], [5, 6]], id="3x2"),
pytest.param([[1, 2, 3], [4, 5, 6]], id="2x3"),
pytest.param([[1, 0], [0, 2]], id="2x2"),
],
)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.array_type(Flags.Dask)
def test_sum_dask_shapes(
array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]
) -> None:
np_arr = np.array(data, dtype=np.float32)
arr = array_type(np_arr)
assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
if isinstance(sum_, types.CupyArray):
sum_ = sum_.get()
np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_)


@pytest.mark.array_type(skip=ATS_SPARSE_DS)
def test_mean(
array_type: ArrayType[Array], axis: Literal[0, 1, None], np_arr: NDArray[DTypeIn]
Expand Down
3 changes: 3 additions & 0 deletions typings/cupy/_core/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ from numpy.typing import NDArray
class ndarray:
dtype: np.dtype[Any]
shape: tuple[int, ...]
size: int
ndim: int

# cupy-specific
def get(self) -> NDArray[Any]: ...

# operators
def __array__(self) -> NDArray[Any]: ...
def __len__(self) -> int: ...
def __getitem__( # never returns scalars
self, index: int | slice | EllipsisType | tuple[int | slice | EllipsisType | None, ...]
) -> Self: ...
Expand All @@ -28,6 +30,7 @@ class ndarray:
def all(self, axis: None = None) -> np.bool: ...
@overload
def all(self, axis: int) -> ndarray: ...
def reshape(self, shape: tuple[int, ...] | int) -> ndarray: ...
def squeeze(self, axis: int | None = None) -> Self: ...
def ravel(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...
def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...
Expand Down
2 changes: 2 additions & 0 deletions typings/dask/array/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Array:
# dask methods and attrs
_meta: _Array
blocks: BlockView
chunks: tuple[tuple[int, ...], ...]
chunksize: tuple[int, ...]

def compute(self) -> _Array: ...
def visualize(
Expand Down
Loading