1+ import jax .numpy as jnp
12import pytest
3+ from jax .typing import ArrayLike
24
3- jax = pytest .importorskip ("jax" , reason = "tests require jax" )
4-
5- import jax .numpy as jnp # noqa: E402
6- from jax .typing import ArrayLike # noqa: E402
5+ import glass ._array_api_utils as _utils
6+ from glass .jax import Generator
77
8- from glass . jax import Generator # noqa: E402
8+ jax = pytest . importorskip ( " jax" , reason = "tests require jax" )
99
1010SEED = 42
1111
1212
1313def test_init () -> None :
14- rng = Generator ( SEED )
14+ rng = _utils . rng_dispatcher ( xp = jnp )
1515 assert isinstance (rng , Generator )
1616 assert isinstance (rng .key , jax .Array )
1717 assert jax .dtypes .issubdtype (rng .key .dtype , jax .dtypes .prng_key )
@@ -31,18 +31,18 @@ def test_from_key() -> None:
3131
3232
3333def test_key () -> None :
34- rng = Generator ( SEED )
35- rngkey , outkey = jax .random .split (rng .key , 2 )
36- key = rng .split ()
37- assert jnp .all (rng .key == rngkey )
34+ rng = _utils . rng_dispatcher ( xp = jnp )
35+ rngkey , outkey = jax .random .split (rng .key , 2 ) # type: ignore[union-attr]
36+ key = rng .split () # type: ignore[union-attr]
37+ assert jnp .all (rng .key == rngkey ) # type: ignore[union-attr]
3838 assert jnp .all (key == outkey )
3939
4040
4141def test_spawn () -> None :
42- rng = Generator ( SEED )
43- key , * subkeys = jax .random .split (rng .key , 4 )
44- subrngs = rng .spawn (3 )
45- assert rng .key == key
42+ rng = _utils . rng_dispatcher ( xp = jnp )
43+ key , * subkeys = jax .random .split (rng .key , 4 ) # type: ignore[union-attr]
44+ subrngs = rng .spawn (3 ) # type: ignore[union-attr]
45+ assert rng .key == key # type: ignore[union-attr]
4646 assert isinstance (subrngs , list )
4747 assert len (subrngs ) == 3
4848 for subrng , subkey in zip (subrngs , subkeys , strict = False ):
@@ -51,48 +51,48 @@ def test_spawn() -> None:
5151
5252
5353def test_random () -> None :
54- rng = Generator ( SEED )
55- key = rng .key
54+ rng = _utils . rng_dispatcher ( xp = jnp )
55+ key = rng .key # type: ignore[union-attr]
5656 rvs = rng .random (size = 10_000 )
57- assert rng .key != key
57+ assert rng .key != key # type: ignore[union-attr]
5858 assert rvs .shape == (10_000 ,)
5959 assert jnp .min (rvs ) >= 0.0
6060 assert jnp .max (rvs ) < 1.0
6161 assert isinstance (rvs , ArrayLike )
6262
6363
6464def test_normal () -> None :
65- rng = Generator ( SEED )
66- key = rng .key
65+ rng = _utils . rng_dispatcher ( xp = jnp )
66+ key = rng .key # type: ignore[union-attr]
6767 rvs = rng .normal (1 , 2 , size = 10_000 )
68- assert rng .key != key
68+ assert rng .key != key # type: ignore[union-attr]
6969 assert rvs .shape == (10_000 ,)
7070 assert isinstance (rvs , ArrayLike )
7171
7272
7373def test_standard_normal () -> None :
74- rng = Generator ( SEED )
75- key = rng .key
74+ rng = _utils . rng_dispatcher ( xp = jnp )
75+ key = rng .key # type: ignore[union-attr]
7676 rvs = rng .standard_normal (size = 10_000 )
77- assert rng .key != key
77+ assert rng .key != key # type: ignore[union-attr]
7878 assert rvs .shape == (10_000 ,)
7979 assert isinstance (rvs , ArrayLike )
8080
8181
8282def test_poisson () -> None :
83- rng = Generator ( SEED )
84- key = rng .key
83+ rng = _utils . rng_dispatcher ( xp = jnp )
84+ key = rng .key # type: ignore[union-attr]
8585 rvs = rng .poisson (lam = 1 , size = 10_000 )
86- assert rng .key != key
86+ assert rng .key != key # type: ignore[union-attr]
8787 assert rvs .shape == (10_000 ,)
8888 assert isinstance (rvs , ArrayLike )
8989
9090
9191def test_uniform () -> None :
92- rng = Generator ( SEED )
93- key = rng .key
92+ rng = _utils . rng_dispatcher ( xp = jnp )
93+ key = rng .key # type: ignore[union-attr]
9494 rvs = rng .uniform (size = 10_000 )
95- assert rng .key != key
95+ assert rng .key != key # type: ignore[union-attr]
9696 assert rvs .shape == (10_000 ,)
9797 assert jnp .min (rvs ) >= 0.0
9898 assert jnp .max (rvs ) < 1.0
0 commit comments