Skip to content

Commit 8f4bb6b

Browse files
gh-961: use rng_dispatcher throughout code (#962)
Co-authored-by: Connor Aird <c.aird@ucl.ac.uk>
1 parent bee3e55 commit 8f4bb6b

File tree

7 files changed

+72
-88
lines changed

7 files changed

+72
-88
lines changed

glass/_array_api_utils.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from glass._types import AnyArray, FloatArray, UnifiedGenerator
3131

32+
SEED = 42
33+
3234

3335
class CompatibleBackendNotFoundError(Exception):
3436
"""
@@ -83,12 +85,14 @@ def default_xp() -> ModuleType:
8385
return import_numpy()
8486

8587

86-
def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
88+
def rng_dispatcher(*, seed: int | AArray = SEED, xp: ModuleType) -> UnifiedGenerator:
8789
"""
8890
Dispatch a random number generator based on the provided array's backend.
8991
9092
Parameters
9193
----------
94+
seed
95+
Seed for the random number generator.
9296
xp
9397
The array library backend to use for array operations.
9498
@@ -101,8 +105,6 @@ def rng_dispatcher(*, xp: ModuleType) -> UnifiedGenerator:
101105
NotImplementedError
102106
If the array backend is not supported.
103107
"""
104-
seed = 42
105-
106108
if xp.__name__ == "jax.numpy":
107109
import glass.jax # noqa: PLC0415
108110

@@ -128,10 +130,7 @@ class Generator:
128130

129131
__slots__ = ("axp", "nxp", "rng")
130132

131-
def __init__(
132-
self,
133-
seed: int | bool | AArray | None = None, # noqa: FBT001
134-
) -> None:
133+
def __init__(self, *, seed: int | AArray = SEED) -> None:
135134
"""
136135
Initialize the Generator.
137136
@@ -146,7 +145,7 @@ def __init__(
146145

147146
self.axp = array_api_strict
148147
self.nxp = numpy
149-
self.rng = self.nxp.random.default_rng(seed=seed)
148+
self.rng = rng_dispatcher(seed=seed, xp=self.nxp)
150149

151150
def random(
152151
self,
@@ -171,7 +170,7 @@ def random(
171170
Array of random floats.
172171
"""
173172
dtype = dtype if dtype is not None else self.nxp.float64
174-
return self.axp.asarray(self.rng.random(size, dtype, out)) # type: ignore[arg-type]
173+
return self.axp.asarray(self.rng.random(size, dtype, out)) # type: ignore[arg-type,call-arg]
175174

176175
def normal(
177176
self,
@@ -216,7 +215,7 @@ def poisson(
216215
-------
217216
Array of samples from the Poisson distribution.
218217
"""
219-
return self.axp.asarray(self.rng.poisson(lam, size))
218+
return self.axp.asarray(self.rng.poisson(lam, size)) # type: ignore[arg-type]
220219

221220
def standard_normal(
222221
self,
@@ -241,7 +240,7 @@ def standard_normal(
241240
Array of samples from the standard normal distribution.
242241
"""
243242
dtype = dtype if dtype is not None else self.nxp.float64
244-
return self.axp.asarray(self.rng.standard_normal(size, dtype, out)) # type: ignore[arg-type]
243+
return self.axp.asarray(self.rng.standard_normal(size, dtype, out)) # type: ignore[arg-type,call-arg]
245244

246245
def uniform(
247246
self,
@@ -265,7 +264,7 @@ def uniform(
265264
-------
266265
Array of samples from the uniform distribution.
267266
"""
268-
return self.axp.asarray(self.rng.uniform(low, high, size))
267+
return self.axp.asarray(self.rng.uniform(low, high, size)) # type: ignore[arg-type]
269268

270269

271270
class XPAdditions:

glass/fields.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,14 @@
2525
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
2626
from typing import Literal
2727

28-
from glass._types import AnyArray, ComplexArray, FloatArray, IntArray, T
28+
from glass._types import (
29+
AnyArray,
30+
ComplexArray,
31+
FloatArray,
32+
IntArray,
33+
T,
34+
UnifiedGenerator,
35+
)
2936

3037

3138
try:
@@ -318,7 +325,7 @@ def _generate_grf(
318325
nside: int,
319326
*,
320327
ncorr: int | None = None,
321-
rng: np.random.Generator | None = None,
328+
rng: UnifiedGenerator | None = None,
322329
) -> Generator[FloatArray]:
323330
"""
324331
Iteratively sample Gaussian random fields (internal use).
@@ -356,7 +363,7 @@ def _generate_grf(
356363
If all gls are empty.
357364
"""
358365
if rng is None:
359-
rng = np.random.default_rng(42)
366+
rng = _utils.rng_dispatcher(xp=np)
360367

361368
# number of gls and number of fields
362369
ngls = len(gls) # type: ignore[arg-type]
@@ -386,7 +393,7 @@ def _generate_grf(
386393
for j, a, s in conditional_dist:
387394
# standard normal random variates for alm
388395
# sample real and imaginary parts, then view as complex number
389-
rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64))
396+
rng.standard_normal(n * (n + 1), np.float64, z.view(np.float64)) # type: ignore[arg-type,call-arg]
390397

391398
# scale by standard deviation of the conditional distribution
392399
# variance is distributed over real and imaginary part

tests/benchmarks/test_fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def function_to_benchmark() -> list[Any]:
175175
generator = glass.fields._generate_grf(
176176
gls,
177177
nside,
178-
rng=urngb if use_rng else None, # type: ignore[arg-type]
178+
rng=urngb if use_rng else None,
179179
ncorr=ncorr,
180180
)
181181
return generator_consumer.consume(generator) # type: ignore[no-any-return]

tests/core/test_array_api_utils.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# check if available for testing
1414
HAVE_ARRAY_API_STRICT = importlib.util.find_spec("array_api_strict") is not None
1515
HAVE_JAX = importlib.util.find_spec("jax") is not None
16-
SEED = 42
1716

1817

1918
def test_rng_dispatcher_numpy() -> None:
@@ -39,61 +38,59 @@ def test_rng_dispatcher_array_api_strict() -> None:
3938

4039
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
4140
def test_init() -> None:
42-
rng = _utils.Generator(SEED)
41+
rng = _utils.Generator()
4342
assert isinstance(rng, _utils.Generator)
4443

4544

4645
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
4746
def test_random() -> None:
4847
import array_api_strict
49-
from array_api_strict._array_object import Array
5048

51-
rng = _utils.Generator(SEED)
49+
rng = _utils.rng_dispatcher(xp=array_api_strict)
5250
rvs = rng.random(size=10_000)
5351
assert rvs.shape == (10_000,)
5452
assert array_api_strict.min(rvs) >= 0.0
5553
assert array_api_strict.max(rvs) < 1.0
56-
assert isinstance(rvs, Array)
54+
assert isinstance(rvs, array_api_strict._array_object.Array)
5755

5856

5957
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
6058
def test_normal() -> None:
61-
from array_api_strict._array_object import Array
59+
import array_api_strict
6260

63-
rng = _utils.Generator(SEED)
61+
rng = _utils.rng_dispatcher(xp=array_api_strict)
6462
rvs = rng.normal(1, 2, size=10_000)
6563
assert rvs.shape == (10_000,)
66-
assert isinstance(rvs, Array)
64+
assert isinstance(rvs, array_api_strict._array_object.Array)
6765

6866

6967
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
7068
def test_standard_normal() -> None:
71-
from array_api_strict._array_object import Array
69+
import array_api_strict
7270

73-
rng = _utils.Generator(SEED)
71+
rng = _utils.rng_dispatcher(xp=array_api_strict)
7472
rvs = rng.standard_normal(size=10_000)
7573
assert rvs.shape == (10_000,)
76-
assert isinstance(rvs, Array)
74+
assert isinstance(rvs, array_api_strict._array_object.Array)
7775

7876

7977
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
8078
def test_poisson() -> None:
81-
from array_api_strict._array_object import Array
79+
import array_api_strict
8280

83-
rng = _utils.Generator(SEED)
81+
rng = _utils.rng_dispatcher(xp=array_api_strict)
8482
rvs = rng.poisson(lam=1, size=10_000)
8583
assert rvs.shape == (10_000,)
86-
assert isinstance(rvs, Array)
84+
assert isinstance(rvs, array_api_strict._array_object.Array)
8785

8886

8987
@pytest.mark.skipif(not HAVE_ARRAY_API_STRICT, reason="test requires array_api_strict")
9088
def test_uniform() -> None:
9189
import array_api_strict
92-
from array_api_strict._array_object import Array
9390

94-
rng = _utils.Generator(SEED)
91+
rng = _utils.rng_dispatcher(xp=array_api_strict)
9592
rvs = rng.uniform(size=10_000)
9693
assert rvs.shape == (10_000,)
9794
assert array_api_strict.min(rvs) >= 0.0
9895
assert array_api_strict.max(rvs) < 1.0
99-
assert isinstance(rvs, Array)
96+
assert isinstance(rvs, array_api_strict._array_object.Array)

tests/core/test_fields.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import glass
11+
import glass._array_api_utils as _utils
1112
import glass.fields
1213

1314
if TYPE_CHECKING:
@@ -386,20 +387,19 @@ def test_generate_grf(compare: type[Compare]) -> None:
386387
gls = [np.asarray([1.0, 0.5, 0.1])]
387388
nside = 4
388389
ncorr = 1
389-
seed = 42
390390

391391
gaussian_fields = list(glass.fields._generate_grf(gls, nside))
392392

393393
assert gaussian_fields[0].shape == (hp.nside2npix(nside),)
394394

395395
# requires resetting the RNG for reproducibility
396-
rng = np.random.default_rng(seed=seed)
396+
rng = _utils.rng_dispatcher(xp=np)
397397
gaussian_fields = list(glass.fields._generate_grf(gls, nside, rng=rng))
398398

399399
assert gaussian_fields[0].shape == (hp.nside2npix(nside),)
400400

401401
# requires resetting the RNG for reproducibility
402-
rng = np.random.default_rng(seed=seed)
402+
rng = _utils.rng_dispatcher(xp=np)
403403
new_gaussian_fields = list(
404404
glass.fields._generate_grf(gls, nside, ncorr=ncorr, rng=rng),
405405
)

tests/core/test_jax.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
import jax.numpy as jnp
12
import pytest
3+
from jax.typing import ArrayLike
24

3-
jax = pytest.importorskip("jax", reason="tests require jax")
4-
5-
import jax.numpy as jnp # noqa: E402
6-
from jax.typing import ArrayLike # noqa: E402
5+
import glass._array_api_utils as _utils
6+
from glass.jax import Generator
77

8-
from glass.jax import Generator # noqa: E402
8+
jax = pytest.importorskip("jax", reason="tests require jax")
99

1010
SEED = 42
1111

1212

1313
def test_init() -> None:
14-
rng = Generator(SEED)
14+
rng = _utils.rng_dispatcher(xp=jnp)
1515
assert isinstance(rng, Generator)
1616
assert isinstance(rng.key, jax.Array)
1717
assert jax.dtypes.issubdtype(rng.key.dtype, jax.dtypes.prng_key)
@@ -31,18 +31,18 @@ def test_from_key() -> None:
3131

3232

3333
def test_key() -> None:
34-
rng = Generator(SEED)
35-
rngkey, outkey = jax.random.split(rng.key, 2)
36-
key = rng.split()
37-
assert jnp.all(rng.key == rngkey)
34+
rng = _utils.rng_dispatcher(xp=jnp)
35+
rngkey, outkey = jax.random.split(rng.key, 2) # type: ignore[union-attr]
36+
key = rng.split() # type: ignore[union-attr]
37+
assert jnp.all(rng.key == rngkey) # type: ignore[union-attr]
3838
assert jnp.all(key == outkey)
3939

4040

4141
def test_spawn() -> None:
42-
rng = Generator(SEED)
43-
key, *subkeys = jax.random.split(rng.key, 4)
44-
subrngs = rng.spawn(3)
45-
assert rng.key == key
42+
rng = _utils.rng_dispatcher(xp=jnp)
43+
key, *subkeys = jax.random.split(rng.key, 4) # type: ignore[union-attr]
44+
subrngs = rng.spawn(3) # type: ignore[union-attr]
45+
assert rng.key == key # type: ignore[union-attr]
4646
assert isinstance(subrngs, list)
4747
assert len(subrngs) == 3
4848
for subrng, subkey in zip(subrngs, subkeys, strict=False):
@@ -51,48 +51,48 @@ def test_spawn() -> None:
5151

5252

5353
def test_random() -> None:
54-
rng = Generator(SEED)
55-
key = rng.key
54+
rng = _utils.rng_dispatcher(xp=jnp)
55+
key = rng.key # type: ignore[union-attr]
5656
rvs = rng.random(size=10_000)
57-
assert rng.key != key
57+
assert rng.key != key # type: ignore[union-attr]
5858
assert rvs.shape == (10_000,)
5959
assert jnp.min(rvs) >= 0.0
6060
assert jnp.max(rvs) < 1.0
6161
assert isinstance(rvs, ArrayLike)
6262

6363

6464
def test_normal() -> None:
65-
rng = Generator(SEED)
66-
key = rng.key
65+
rng = _utils.rng_dispatcher(xp=jnp)
66+
key = rng.key # type: ignore[union-attr]
6767
rvs = rng.normal(1, 2, size=10_000)
68-
assert rng.key != key
68+
assert rng.key != key # type: ignore[union-attr]
6969
assert rvs.shape == (10_000,)
7070
assert isinstance(rvs, ArrayLike)
7171

7272

7373
def test_standard_normal() -> None:
74-
rng = Generator(SEED)
75-
key = rng.key
74+
rng = _utils.rng_dispatcher(xp=jnp)
75+
key = rng.key # type: ignore[union-attr]
7676
rvs = rng.standard_normal(size=10_000)
77-
assert rng.key != key
77+
assert rng.key != key # type: ignore[union-attr]
7878
assert rvs.shape == (10_000,)
7979
assert isinstance(rvs, ArrayLike)
8080

8181

8282
def test_poisson() -> None:
83-
rng = Generator(SEED)
84-
key = rng.key
83+
rng = _utils.rng_dispatcher(xp=jnp)
84+
key = rng.key # type: ignore[union-attr]
8585
rvs = rng.poisson(lam=1, size=10_000)
86-
assert rng.key != key
86+
assert rng.key != key # type: ignore[union-attr]
8787
assert rvs.shape == (10_000,)
8888
assert isinstance(rvs, ArrayLike)
8989

9090

9191
def test_uniform() -> None:
92-
rng = Generator(SEED)
93-
key = rng.key
92+
rng = _utils.rng_dispatcher(xp=jnp)
93+
key = rng.key # type: ignore[union-attr]
9494
rvs = rng.uniform(size=10_000)
95-
assert rng.key != key
95+
assert rng.key != key # type: ignore[union-attr]
9696
assert rvs.shape == (10_000,)
9797
assert jnp.min(rvs) >= 0.0
9898
assert jnp.max(rvs) < 1.0

0 commit comments

Comments
 (0)