Skip to content

Commit e29172b

Browse files
committed
allow random int arrays
1 parent d30b2f2 commit e29172b

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

src/testing/fast_array_utils/_array_type.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import enum
77
from dataclasses import KW_ONLY, dataclass, field
8-
from functools import cached_property
8+
from functools import cached_property, partial
99
from importlib.metadata import version
1010
from typing import TYPE_CHECKING, Generic, TypeVar, cast
1111

@@ -42,6 +42,9 @@ def __call__(
4242

4343
_DTypeLikeFloat32 = np.dtype[np.float32] | type[np.float32]
4444
_DTypeLikeFloat64 = np.dtype[np.float64] | type[np.float64]
45+
_DTypeLikeInt32 = np.dtype[np.int32] | type[np.int32]
46+
_DTypeLikeIn64 = np.dtype[np.int64] | type[np.int64]
47+
_DTypeLikeNum = _DTypeLikeFloat32 | _DTypeLikeFloat64 | _DTypeLikeInt32 | _DTypeLikeIn64
4548
else:
4649
Arr = TypeVar("Arr")
4750
Inner = TypeVar("Inner")
@@ -155,7 +158,7 @@ def random(
155158
self,
156159
shape: tuple[int, int],
157160
*,
158-
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64 | None = None,
161+
dtype: _DTypeLikeNum | None = None,
159162
gen: np.random.Generator | None = None,
160163
# sparse only
161164
density: float | np.floating[Any] = 0.01,
@@ -165,7 +168,7 @@ def random(
165168

166169
match self.mod, self.name, self.inner:
167170
case "numpy", "ndarray", None:
168-
return cast("Arr", gen.random(shape, dtype=dtype or np.float64))
171+
return cast("Arr", random_array(shape, dtype=dtype, rng=gen))
169172
case "scipy.sparse", (
170173
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
171174
) as cls_name, None:
@@ -179,7 +182,7 @@ def random(
179182
),
180183
)
181184
case "cupy", "ndarray", None:
182-
return self(gen.random(shape, dtype=dtype or np.float64))
185+
return self(random_array(shape, dtype=dtype, rng=gen))
183186
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
184187
import cupy as cu
185188

@@ -363,12 +366,28 @@ def _to_cupy_sparse(
363366
return self.cls(x) # type: ignore[call-arg,arg-type, return-value]
364367

365368

369+
def random_array(
370+
shape: tuple[int, int],
371+
*,
372+
dtype: _DTypeLikeNum | None = None,
373+
rng: np.random.Generator | None = None,
374+
) -> Array:
375+
"""Create a random array."""
376+
rng = np.random.default_rng(rng)
377+
f = (
378+
partial(rng.integers, 0, 10_000)
379+
if dtype is not None and np.dtype(dtype).kind in "iu"
380+
else rng.random
381+
)
382+
return f(shape, dtype=dtype) # type: ignore[arg-type]
383+
384+
366385
def random_mat(
367386
shape: tuple[int, int],
368387
*,
369388
density: float | np.floating[Any] = 0.01,
370389
format: Literal["csr", "csc"] = "csr", # noqa: A002
371-
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64 | None = None,
390+
dtype: _DTypeLikeNum | None = None,
372391
container: Literal["array", "matrix"] = "array",
373392
rng: np.random.Generator | None = None,
374393
) -> types.CSBase:

tests/test_sparse.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from fast_array_utils.types import CSBase
2121
from testing.fast_array_utils import ArrayType
22-
from testing.fast_array_utils._array_type import _DTypeLikeFloat32, _DTypeLikeFloat64
22+
from testing.fast_array_utils._array_type import _DTypeLikeNum
2323

2424

2525
pytestmark = [pytest.mark.skipif(not find_spec("scipy"), reason="scipy not installed")]
@@ -48,9 +48,7 @@ def dtype(request: pytest.FixtureRequest) -> type[np.float32 | np.float64]:
4848
@pytest.mark.array_type(select=Flags.Sparse, skip=Flags.Dask | Flags.Disk | Flags.Gpu)
4949
@pytest.mark.parametrize("order", ["C", "F"])
5050
def test_to_dense(
51-
array_type: ArrayType[CSBase, None],
52-
order: Literal["C", "F"],
53-
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64,
51+
array_type: ArrayType[CSBase, None], order: Literal["C", "F"], dtype: _DTypeLikeNum
5452
) -> None:
5553
mat = array_type.random((10, 10), density=0.1, dtype=dtype)
5654
with WARNS_NUMBA if not find_spec("numba") else nullcontext():
@@ -68,7 +66,7 @@ def test_to_dense_benchmark(
6866
benchmark: BenchmarkFixture,
6967
array_type: ArrayType[CSBase, None],
7068
order: Literal["C", "F"],
71-
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64,
69+
dtype: _DTypeLikeNum,
7270
) -> None:
7371
mat = array_type.random((10_000, 10_000), dtype=dtype)
7472
to_dense(mat, order=order) # warmup: numba compile

tests/test_stats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_dask_constant_blocks(
199199
@pytest.mark.benchmark
200200
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
201201
@pytest.mark.parametrize("func", [stats.sum, stats.mean, stats.mean_var, stats.is_constant])
202-
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float
202+
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int32])
203203
def test_stats_benchmark(
204204
benchmark: BenchmarkFixture,
205205
func: BenchFun,

0 commit comments

Comments
 (0)