Skip to content

Commit 80171a3

Browse files
committed
WIP is-constant
1 parent ca0fff5 commit 80171a3

File tree

3 files changed

+174
-1
lines changed

3 files changed

+174
-1
lines changed

src/fast_array_utils/stats/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from __future__ import annotations
55

6+
from ._is_constant import is_constant
67
from ._sum import sum
78

89

9-
__all__ = ["sum"]
10+
__all__ = ["is_constant", "sum"]
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from functools import partial, singledispatch
5+
from numbers import Integral
6+
from typing import TYPE_CHECKING, overload
7+
8+
import numba
9+
import numpy as np
10+
11+
from ..types import CSBase, DaskArray, H5Dataset, ZarrArray
12+
13+
14+
if TYPE_CHECKING:
15+
from collections.abc import Callable
16+
from typing import Any, Literal, TypeAlias, TypeVar
17+
18+
from numpy.typing import NDArray
19+
20+
_Array: TypeAlias = NDArray[Any] | CSBase | H5Dataset | ZarrArray | DaskArray
21+
22+
C = TypeVar("C", bound=Callable[..., Any])
23+
24+
25+
@overload
26+
def is_constant(a: _Array, axis: None = None) -> bool: ...
27+
@overload
28+
def is_constant(a: _Array, axis: Literal[0, 1]) -> NDArray[np.bool_]: ...
29+
30+
31+
def is_constant(a: _Array, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool_]:
32+
"""Check whether values in array are constant.
33+
34+
Params
35+
------
36+
a
37+
Array to check
38+
axis
39+
Axis to reduce over.
40+
41+
42+
Returns:
43+
-------
44+
Boolean array, True values were constant.
45+
46+
Example:
47+
-------
48+
>>> a = np.array([[0, 1], [0, 0]])
49+
>>> a
50+
array([[0, 1],
51+
[0, 0]])
52+
>>> is_constant(a)
53+
False
54+
>>> is_constant(a, axis=0)
55+
array([ True, False])
56+
>>> is_constant(a, axis=1)
57+
array([False, True])
58+
59+
"""
60+
if axis is not None:
61+
if not isinstance(axis, Integral):
62+
msg = "axis must be integer or None."
63+
raise TypeError(msg)
64+
if axis not in (0, 1):
65+
msg = "We only support axis 0 and 1 at the moment"
66+
raise NotImplementedError(msg)
67+
68+
return _is_constant(a, axis)
69+
70+
71+
@singledispatch
72+
def _is_constant(a: _Array, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool_]:
73+
raise NotImplementedError
74+
75+
76+
@_is_constant.register(np.ndarray)
77+
@_is_constant.register(H5Dataset)
78+
@_is_constant.register(ZarrArray)
79+
def _(a: NDArray[Any], axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool_]:
80+
# Should eventually support nd, not now.
81+
match axis:
82+
case None:
83+
return bool((a == a.flat[0]).all())
84+
case 0:
85+
return _is_constant_rows(a.T)
86+
case 1:
87+
return _is_constant_rows(a)
88+
89+
90+
def _is_constant_rows(a: NDArray[Any]) -> NDArray[np.bool_]:
91+
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
92+
return (a == b).all(axis=1) # type: ignore[no-any-return]
93+
94+
95+
@_is_constant.register(CSBase) # type: ignore[call-overload,misc]
96+
def _(a: CSBase, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool_]:
97+
n_row, n_col = a.shape
98+
if axis is None:
99+
if len(a.data) == n_row * n_col:
100+
return is_constant(a.data)
101+
return (a.data == 0).all() # type: ignore[no-any-return]
102+
shape = (n_row, n_col) if axis == 1 else (n_col, n_row)
103+
match axis, a.format:
104+
case 0, "csr":
105+
a = a.T.tocsr()
106+
case 1, "csc":
107+
a = a.T.tocsc()
108+
return _is_constant_csr_rows(a.data, a.indptr, shape)
109+
110+
111+
@numba.njit(cache=True)
112+
def _is_constant_csr_rows(
113+
data: NDArray[np.number[Any]],
114+
indptr: NDArray[np.integer[Any]],
115+
shape: tuple[int, int],
116+
) -> NDArray[np.bool_]:
117+
n = len(indptr) - 1
118+
result = np.ones(n, dtype=np.bool_)
119+
for i in numba.prange(n):
120+
start = indptr[i]
121+
stop = indptr[i + 1]
122+
val = data[start] if stop - start == shape[1] else 0
123+
for j in range(start, stop):
124+
if data[j] != val:
125+
result[i] = False
126+
break
127+
return result
128+
129+
130+
@_is_constant.register(DaskArray)
131+
def _(a: DaskArray, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool_]:
132+
if TYPE_CHECKING:
133+
from dask.array.core import map_blocks
134+
else:
135+
from dask.array import map_blocks
136+
137+
if axis is None:
138+
v = a[tuple(0 for _ in range(a.ndim))].compute()
139+
return (a == v).all() # type: ignore[no-any-return]
140+
# TODO(flying-sheep): use overlapping blocks and reduction instead of `drop_axis` # noqa: TD003
141+
return map_blocks( # type: ignore[no-any-return,no-untyped-call]
142+
partial(is_constant, axis=axis),
143+
a,
144+
drop_axis=axis,
145+
meta=np.array([], dtype=a.dtype),
146+
)

tests/test_stats.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,32 @@ def test_sum(
7171
np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis, dtype=dtype_arg)) # type: ignore[arg-type]
7272

7373

74+
@pytest.mark.parametrize(
75+
("axis", "expected"),
76+
[
77+
pytest.param(None, False, id="None"),
78+
pytest.param(0, [True, True, False, False], id="0"),
79+
pytest.param(1, [False, False, True, True, False, True], id="1"),
80+
],
81+
)
82+
def test_is_constant_dask(
83+
to_array: ToArray, axis: Literal[0, 1, None], expected: bool | list[bool]
84+
) -> None:
85+
x_data = [
86+
[0, 0, 1, 1],
87+
[0, 0, 1, 1],
88+
[0, 0, 0, 0],
89+
[0, 0, 0, 0],
90+
[0, 0, 1, 0],
91+
[0, 0, 0, 0],
92+
]
93+
x = to_array(x_data)
94+
result = stats.is_constant(x, axis=axis)
95+
if isinstance(result, types.DaskArray):
96+
result = result.compute()
97+
np.testing.assert_array_equal(expected, result)
98+
99+
74100
@pytest.mark.benchmark
75101
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float
76102
def test_sum_benchmark(

0 commit comments

Comments
 (0)