11from __future__ import annotations
22
3+ import importlib .util
34from typing import TYPE_CHECKING
45
56import healpy as hp
1617
1718 from tests .fixtures .helper_classes import Compare
1819
20+ HAVE_JAX = importlib .util .find_spec ("jax" ) is not None
21+
1922
2023@pytest .fixture (scope = "session" )
2124def not_triangle_numbers () -> list [int ]:
@@ -139,35 +142,75 @@ def test_iternorm(xp: ModuleType) -> None:
139142 assert s .shape == (3 ,)
140143
141144
142- def test_cls2cov (compare : type [Compare ], xp : ModuleType ) -> None :
143- # Call jax version of iternorm once jax version is written
144- if xp .__name__ == "jax.numpy" :
145- pytest .skip ("Arrays in cls2cov are not immutable, so do not support jax" )
145+ @pytest .mark .skipif (not HAVE_JAX , reason = "test requires jax" )
146+ def test_cls2cov_jax (compare : type [Compare ], jnp : ModuleType ) -> None :
147+ nl , nf , nc = 3 , 3 , 2
146148
149+ generator = glass .cls2cov (
150+ [
151+ jnp .asarray (arr )
152+ for arr in [
153+ [1.0 , 0.5 , 0.3 ],
154+ [0.8 , 0.4 , 0.2 ],
155+ [0.7 , 0.6 , 0.1 ],
156+ [0.9 , 0.5 , 0.3 ],
157+ [0.6 , 0.3 , 0.2 ],
158+ [0.8 , 0.7 , 0.4 ],
159+ ]
160+ ],
161+ nl ,
162+ nf ,
163+ nc ,
164+ )
165+
166+ cov1 = jnp .asarray (next (generator ), copy = False )
167+ cov2 = jnp .asarray (next (generator ), copy = False )
168+ cov3 = next (generator )
169+
170+ assert cov1 .shape == (nl , nc + 1 )
171+ assert cov2 .shape == (nl , nc + 1 )
172+ assert cov3 .shape == (nl , nc + 1 )
173+
174+ assert cov1 .dtype == jnp .float64
175+ assert cov2 .dtype == jnp .float64
176+ assert cov3 .dtype == jnp .float64
177+
178+ # cov1 has the expected value for the first iteration (different to cov1_copy)
179+ compare .assert_allclose (cov1 [:, 0 ], jnp .asarray ([0.5 , 0.25 , 0.15 ]))
180+
181+ # The copies should not be equal
182+ with pytest .raises (AssertionError , match = "Not equal to tolerance" ):
183+ compare .assert_allclose (cov1 , cov2 )
184+
185+ with pytest .raises (AssertionError , match = "Not equal to tolerance" ):
186+ compare .assert_allclose (cov2 , cov3 )
187+
188+
189+ def test_cls2cov_no_jax (compare : type [Compare ], xpb : ModuleType ) -> None :
147190 # check output values and shape
148191
149192 nl , nf , nc = 3 , 2 , 2
150193
151194 generator = glass .cls2cov (
152- [xp .asarray ([1.0 , 0.5 , 0.3 ]), None , xp .asarray ([0.7 , 0.6 , 0.1 ])],
195+ [xpb .asarray ([1.0 , 0.5 , 0.3 ]), None , xpb .asarray ([0.7 , 0.6 , 0.1 ])],
153196 nl ,
154197 nf ,
155198 nc ,
156199 )
157200 cov = next (generator )
158201
159202 assert cov .shape == (nl , nc + 1 )
160- assert cov .dtype == xp .float64
203+ assert cov .dtype == xpb .float64
161204
162- compare .assert_allclose (cov [:, 0 ], xp .asarray ([0.5 , 0.25 , 0.15 ]))
205+ compare .assert_allclose (cov [:, 0 ], xpb .asarray ([0.5 , 0.25 , 0.15 ]))
163206 compare .assert_allclose (cov [:, 1 ], 0 )
164207 compare .assert_allclose (cov [:, 2 ], 0 )
165208
166209 # test negative value error
167210
168211 generator = glass .cls2cov (
169212 [
170- xp .asarray (arr )
213+ xpb .asarray (arr )
171214 for arr in [
172215 [- 1.0 , 0.5 , 0.3 ],
173216 [0.8 , 0.4 , 0.2 ],
@@ -187,7 +230,7 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
187230
188231 generator = glass .cls2cov (
189232 [
190- xp .asarray (arr )
233+ xpb .asarray (arr )
191234 for arr in [
192235 [1.0 , 0.5 , 0.3 ],
193236 [0.8 , 0.4 , 0.2 ],
@@ -202,23 +245,34 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
202245 nc ,
203246 )
204247
205- cov1 = xp .asarray (next (generator ), copy = True )
206- cov2 = xp .asarray (next (generator ), copy = True )
248+ cov1 = xpb .asarray (next (generator ), copy = False )
249+ cov1_copy = xpb .asarray (cov1 , copy = True )
250+ cov2 = xpb .asarray (next (generator ), copy = False )
251+ cov2_copy = xpb .asarray (cov2 , copy = True )
207252 cov3 = next (generator )
208253
209254 assert cov1 .shape == (nl , nc + 1 )
210255 assert cov2 .shape == (nl , nc + 1 )
211256 assert cov3 .shape == (nl , nc + 1 )
212257
213- assert cov1 .dtype == xp .float64
214- assert cov2 .dtype == xp .float64
215- assert cov3 .dtype == xp .float64
258+ assert cov1 .dtype == xpb .float64
259+ assert cov2 .dtype == xpb .float64
260+ assert cov3 .dtype == xpb .float64
261+
262+ # cov1|2|3 reuse the same data, so should all equal the third result
263+ compare .assert_allclose (cov1 [:, 0 ], xpb .asarray ([0.45 , 0.25 , 0.15 ]))
264+ compare .assert_allclose (cov1 , cov2 )
265+ compare .assert_allclose (cov2 , cov3 )
266+
267+ # cov1 has the expected value for the first iteration (different to cov1_copy)
268+ compare .assert_allclose (cov1_copy [:, 0 ], xpb .asarray ([0.5 , 0.25 , 0.15 ]))
216269
270+ # The copies should not be equal
217271 with pytest .raises (AssertionError , match = "Not equal to tolerance" ):
218- compare .assert_allclose (cov1 , cov2 )
272+ compare .assert_allclose (cov1_copy , cov2_copy )
219273
220274 with pytest .raises (AssertionError , match = "Not equal to tolerance" ):
221- compare .assert_allclose (cov2 , cov3 )
275+ compare .assert_allclose (cov2_copy , cov3 )
222276
223277
224278def test_lognormal_gls () -> None :
0 commit comments