Skip to content

Commit d0ff6e6

Browse files
gh-637: Use xpx.at in fields.cls2cov (#874)
Co-authored-by: Patrick J. Roddy <patrickjamesroddy@gmail.com>
1 parent bfaf93e commit d0ff6e6

File tree

4 files changed

+83
-34
lines changed

4 files changed

+83
-34
lines changed

glass/fields.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,12 @@ def cls2cov(
213213
for j in range(nf):
214214
begin, end = end, end + j + 1
215215
for i, cl in enumerate(cls[begin:end][: nc + 1]):
216-
if i == 0 and np.any(xp.less(cl, 0)):
216+
if i == 0 and xp.any(xp.less(cl, 0)):
217217
msg = "negative values in cl"
218218
raise ValueError(msg)
219-
n = cl.size
220-
cov[:n, i] = cl
221-
cov[n:, i] = 0
219+
n = cl.shape[0]
220+
cov = xpx.at(cov)[:n, i].set(cl)
221+
cov = xpx.at(cov)[n:, i].set(0.0)
222222
cov /= 2
223223
yield cov
224224

tests/benchmarks/test_fields.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,11 @@ def test_cls2cov(
127127
benchmark: BenchmarkFixture,
128128
compare: Compare,
129129
generator_consumer: GeneratorConsumer,
130-
urngb: UnifiedGenerator,
131130
xpb: ModuleType,
132131
) -> None:
133132
"""Benchmarks for glass.cls2cov."""
134-
# check output values and shape
135-
136133
nl, nf, nc = 3, 2, 2
137-
array_in = [urngb.random(3) for _ in range(1_000)]
134+
array_in = [xpb.arange(i + 1.0, i + 4.0) for i in range(1_000)]
138135

139136
def function_to_benchmark() -> list[Any]:
140137
generator = glass.cls2cov(
@@ -151,16 +148,8 @@ def function_to_benchmark() -> list[Any]:
151148
assert cov.shape == (nl, nc + 1)
152149
assert cov.dtype == xpb.float64
153150

154-
compare.assert_allclose(
155-
cov[:, 0],
156-
xpb.asarray([0.348684, 0.047089, 0.487811]),
157-
atol=1e-6,
158-
)
159-
compare.assert_allclose(
160-
cov[:, 1],
161-
[0.38057, 0.393032, 0.064057],
162-
atol=1e-6,
163-
)
151+
compare.assert_allclose(cov[:, 0], xpb.asarray([1.0, 1.5, 2.0]))
152+
compare.assert_allclose(cov[:, 1], xpb.asarray([1.5, 2.0, 2.5]))
164153
compare.assert_allclose(cov[:, 2], 0)
165154

166155

tests/core/test_fields.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import importlib.util
34
from typing import TYPE_CHECKING
45

56
import healpy as hp
@@ -16,6 +17,8 @@
1617

1718
from tests.fixtures.helper_classes import Compare
1819

20+
HAVE_JAX = importlib.util.find_spec("jax") is not None
21+
1922

2023
@pytest.fixture(scope="session")
2124
def not_triangle_numbers() -> list[int]:
@@ -139,35 +142,75 @@ def test_iternorm(xp: ModuleType) -> None:
139142
assert s.shape == (3,)
140143

141144

142-
def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
143-
# Call jax version of iternorm once jax version is written
144-
if xp.__name__ == "jax.numpy":
145-
pytest.skip("Arrays in cls2cov are not immutable, so do not support jax")
145+
@pytest.mark.skipif(not HAVE_JAX, reason="test requires jax")
146+
def test_cls2cov_jax(compare: type[Compare], jnp: ModuleType) -> None:
147+
nl, nf, nc = 3, 3, 2
146148

149+
generator = glass.cls2cov(
150+
[
151+
jnp.asarray(arr)
152+
for arr in [
153+
[1.0, 0.5, 0.3],
154+
[0.8, 0.4, 0.2],
155+
[0.7, 0.6, 0.1],
156+
[0.9, 0.5, 0.3],
157+
[0.6, 0.3, 0.2],
158+
[0.8, 0.7, 0.4],
159+
]
160+
],
161+
nl,
162+
nf,
163+
nc,
164+
)
165+
166+
cov1 = jnp.asarray(next(generator), copy=False)
167+
cov2 = jnp.asarray(next(generator), copy=False)
168+
cov3 = next(generator)
169+
170+
assert cov1.shape == (nl, nc + 1)
171+
assert cov2.shape == (nl, nc + 1)
172+
assert cov3.shape == (nl, nc + 1)
173+
174+
assert cov1.dtype == jnp.float64
175+
assert cov2.dtype == jnp.float64
176+
assert cov3.dtype == jnp.float64
177+
178+
# cov1 has the expected value for the first iteration (different to cov1_copy)
179+
compare.assert_allclose(cov1[:, 0], jnp.asarray([0.5, 0.25, 0.15]))
180+
181+
# The copies should not be equal
182+
with pytest.raises(AssertionError, match="Not equal to tolerance"):
183+
compare.assert_allclose(cov1, cov2)
184+
185+
with pytest.raises(AssertionError, match="Not equal to tolerance"):
186+
compare.assert_allclose(cov2, cov3)
187+
188+
189+
def test_cls2cov_no_jax(compare: type[Compare], xpb: ModuleType) -> None:
147190
# check output values and shape
148191

149192
nl, nf, nc = 3, 2, 2
150193

151194
generator = glass.cls2cov(
152-
[xp.asarray([1.0, 0.5, 0.3]), None, xp.asarray([0.7, 0.6, 0.1])],
195+
[xpb.asarray([1.0, 0.5, 0.3]), None, xpb.asarray([0.7, 0.6, 0.1])],
153196
nl,
154197
nf,
155198
nc,
156199
)
157200
cov = next(generator)
158201

159202
assert cov.shape == (nl, nc + 1)
160-
assert cov.dtype == xp.float64
203+
assert cov.dtype == xpb.float64
161204

162-
compare.assert_allclose(cov[:, 0], xp.asarray([0.5, 0.25, 0.15]))
205+
compare.assert_allclose(cov[:, 0], xpb.asarray([0.5, 0.25, 0.15]))
163206
compare.assert_allclose(cov[:, 1], 0)
164207
compare.assert_allclose(cov[:, 2], 0)
165208

166209
# test negative value error
167210

168211
generator = glass.cls2cov(
169212
[
170-
xp.asarray(arr)
213+
xpb.asarray(arr)
171214
for arr in [
172215
[-1.0, 0.5, 0.3],
173216
[0.8, 0.4, 0.2],
@@ -187,7 +230,7 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
187230

188231
generator = glass.cls2cov(
189232
[
190-
xp.asarray(arr)
233+
xpb.asarray(arr)
191234
for arr in [
192235
[1.0, 0.5, 0.3],
193236
[0.8, 0.4, 0.2],
@@ -202,23 +245,34 @@ def test_cls2cov(compare: type[Compare], xp: ModuleType) -> None:
202245
nc,
203246
)
204247

205-
cov1 = xp.asarray(next(generator), copy=True)
206-
cov2 = xp.asarray(next(generator), copy=True)
248+
cov1 = xpb.asarray(next(generator), copy=False)
249+
cov1_copy = xpb.asarray(cov1, copy=True)
250+
cov2 = xpb.asarray(next(generator), copy=False)
251+
cov2_copy = xpb.asarray(cov2, copy=True)
207252
cov3 = next(generator)
208253

209254
assert cov1.shape == (nl, nc + 1)
210255
assert cov2.shape == (nl, nc + 1)
211256
assert cov3.shape == (nl, nc + 1)
212257

213-
assert cov1.dtype == xp.float64
214-
assert cov2.dtype == xp.float64
215-
assert cov3.dtype == xp.float64
258+
assert cov1.dtype == xpb.float64
259+
assert cov2.dtype == xpb.float64
260+
assert cov3.dtype == xpb.float64
261+
262+
# cov1|2|3 reuse the same data, so should all equal the third result
263+
compare.assert_allclose(cov1[:, 0], xpb.asarray([0.45, 0.25, 0.15]))
264+
compare.assert_allclose(cov1, cov2)
265+
compare.assert_allclose(cov2, cov3)
266+
267+
# cov1 has the expected value for the first iteration (different to cov1_copy)
268+
compare.assert_allclose(cov1_copy[:, 0], xpb.asarray([0.5, 0.25, 0.15]))
216269

270+
# The copies should not be equal
217271
with pytest.raises(AssertionError, match="Not equal to tolerance"):
218-
compare.assert_allclose(cov1, cov2)
272+
compare.assert_allclose(cov1_copy, cov2_copy)
219273

220274
with pytest.raises(AssertionError, match="Not equal to tolerance"):
221-
compare.assert_allclose(cov2, cov3)
275+
compare.assert_allclose(cov2_copy, cov3)
222276

223277

224278
def test_lognormal_gls() -> None:

tests/fixtures/array_backends.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ def xpb(request: pytest.FixtureRequest) -> ModuleType:
125125
return request.param # type: ignore[no-any-return]
126126

127127

128+
@pytest.fixture(scope="session")
129+
def jnp() -> ModuleType:
130+
"""Fixture for the jax.numpy array backend."""
131+
return xp_available_backends["jax.numpy"]
132+
133+
128134
@pytest.fixture(scope="session")
129135
def uxpx(xp: ModuleType) -> _utils.XPAdditions:
130136
"""

0 commit comments

Comments
 (0)