55
66import enum
77from dataclasses import KW_ONLY , dataclass , field
8- from functools import cached_property
8+ from functools import cached_property , partial
99from importlib .metadata import version
1010from typing import TYPE_CHECKING , Generic , TypeVar , cast
1111
@@ -42,6 +42,9 @@ def __call__(
4242
4343 _DTypeLikeFloat32 = np .dtype [np .float32 ] | type [np .float32 ]
4444 _DTypeLikeFloat64 = np .dtype [np .float64 ] | type [np .float64 ]
45+ _DTypeLikeInt32 = np .dtype [np .int32 ] | type [np .int32 ]
46+ _DTypeLikeIn64 = np .dtype [np .int64 ] | type [np .int64 ]
47+ _DTypeLikeNum = _DTypeLikeFloat32 | _DTypeLikeFloat64 | _DTypeLikeInt32 | _DTypeLikeIn64
4548else :
4649 Arr = TypeVar ("Arr" )
4750 Inner = TypeVar ("Inner" )
@@ -155,7 +158,7 @@ def random(
155158 self ,
156159 shape : tuple [int , int ],
157160 * ,
158- dtype : _DTypeLikeFloat32 | _DTypeLikeFloat64 | None = None ,
161+ dtype : _DTypeLikeNum | None = None ,
159162 gen : np .random .Generator | None = None ,
160163 # sparse only
161164 density : float | np .floating [Any ] = 0.01 ,
@@ -165,7 +168,7 @@ def random(
165168
166169 match self .mod , self .name , self .inner :
167170 case "numpy" , "ndarray" , None :
168- return cast ("Arr" , gen . random (shape , dtype = dtype or np . float64 ))
171+ return cast ("Arr" , random_array (shape , dtype = dtype , rng = gen ))
169172 case "scipy.sparse" , (
170173 "csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
171174 ) as cls_name , None :
@@ -179,7 +182,7 @@ def random(
179182 ),
180183 )
181184 case "cupy" , "ndarray" , None :
182- return self (gen . random (shape , dtype = dtype or np . float64 ))
185+ return self (random_array (shape , dtype = dtype , rng = gen ))
183186 case "cupyx.scipy.sparse" , ("csr_matrix" | "csc_matrix" ) as cls_name , None :
184187 import cupy as cu
185188
@@ -363,12 +366,28 @@ def _to_cupy_sparse(
363366 return self .cls (x ) # type: ignore[call-arg,arg-type, return-value]
364367
365368
369+ def random_array (
370+ shape : tuple [int , int ],
371+ * ,
372+ dtype : _DTypeLikeNum | None = None ,
373+ rng : np .random .Generator | None = None ,
374+ ) -> Array :
375+ """Create a random array."""
376+ rng = np .random .default_rng (rng )
377+ f = (
378+ partial (rng .integers , 0 , 10_000 )
379+ if dtype is not None and np .dtype (dtype ).kind in "iu"
380+ else rng .random
381+ )
382+ return f (shape , dtype = dtype ) # type: ignore[arg-type]
383+
384+
366385def random_mat (
367386 shape : tuple [int , int ],
368387 * ,
369388 density : float | np .floating [Any ] = 0.01 ,
370389 format : Literal ["csr" , "csc" ] = "csr" , # noqa: A002
371- dtype : _DTypeLikeFloat32 | _DTypeLikeFloat64 | None = None ,
390+ dtype : _DTypeLikeNum | None = None ,
372391 container : Literal ["array" , "matrix" ] = "array" ,
373392 rng : np .random .Generator | None = None ,
374393) -> types .CSBase :
0 commit comments