Skip to content

Commit 7c93e54

Browse files
committed
test: redo benchmarks to just cover jax and add timing comps
1 parent 1b2baf5 commit 7c93e54

File tree

3 files changed

+92
-33
lines changed

3 files changed

+92
-33
lines changed

jax_galsim/spergel.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,34 +267,34 @@ def scale_radius(self):
267267
def _r0(self):
268268
return self.scale_radius
269269

270-
@property
270+
@lazy_property
271271
def _inv_r0(self):
272272
return 1.0 / self._r0
273273

274-
@property
274+
@lazy_property
275275
def _r0_sq(self):
276276
return self._r0 * self._r0
277277

278-
@property
278+
@lazy_property
279279
def _inv_r0_sq(self):
280280
return self._inv_r0 * self._inv_r0
281281

282-
@property
282+
@lazy_property
283283
@implements(_galsim.spergel.Spergel.half_light_radius)
284284
def half_light_radius(self):
285285
return self._r0 * calculateFluxRadius(0.5, self.nu)
286286

287-
@property
287+
@lazy_property
288288
def _shootxnorm(self):
289289
"""Normalization for photon shooting"""
290290
return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu))
291291

292-
@property
292+
@lazy_property
293293
def _xnorm(self):
294294
"""Normalization of xValue"""
295295
return self._shootxnorm * self.flux * self._inv_r0_sq
296296

297-
@property
297+
@lazy_property
298298
def _xnorm0(self):
299299
"""return z^nu K_nu(z) for z=0"""
300300
return jax.lax.select(
@@ -338,21 +338,21 @@ def __str__(self):
338338
s += ")"
339339
return s
340340

341-
@property
341+
@lazy_property
342342
def _maxk(self):
343343
"""(1+ (k r0)^2)^(-1-nu) = maxk_threshold"""
344344
res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0
345345
return jnp.sqrt(res) / self._r0
346346

347-
@property
347+
@lazy_property
348348
def _stepk(self):
349349
R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu)
350350
R *= self._r0
351351
# Go to at least 5*hlr
352352
R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius)
353353
return jnp.pi / R
354354

355-
@property
355+
@lazy_property
356356
def _max_sb(self):
357357
# from SBSpergelImpl.h
358358
return jnp.abs(self._xnorm) * self._xnorm0

tests/jax/test_benchmarks.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -179,18 +179,12 @@ def test_benchmark_spergel_conv(benchmark, kind):
179179
dt = _run_benchmarks(
180180
benchmark, kind, lambda: _run_spergel_bench_conv_jit().block_until_ready()
181181
)
182-
print(f"jax-galsim time: {dt:0.4g} ms", end=" ")
183-
184-
185-
@pytest.mark.parametrize("kind", ["compile", "run"])
186-
def test_benchmark_spergel_conv_galsim(benchmark, kind):
187-
dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_conv(_galsim))
188-
print(f"galsim time: {dt:0.4g} ms", end=" ")
182+
print(f"time: {dt:0.4g} ms", end=" ")
189183

190184

191185
def _run_spergel_bench_xvalue(gsmod):
192186
obj = gsmod.Spergel(nu=-0.6, scale_radius=5)
193-
return obj.drawImage(nx=50, ny=50, scale=0.2, method="no_pixel").array
187+
return obj.drawImage(nx=1024, ny=1204, scale=0.05, method="no_pixel").array
194188

195189

196190
_run_spergel_bench_xvalue_jit = jax.jit(partial(_run_spergel_bench_xvalue, jgs))
@@ -201,18 +195,12 @@ def test_benchmark_spergel_xvalue(benchmark, kind):
201195
dt = _run_benchmarks(
202196
benchmark, kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready()
203197
)
204-
print(f"jax-galsim time: {dt:0.4g} ms", end=" ")
205-
206-
207-
@pytest.mark.parametrize("kind", ["compile", "run"])
208-
def test_benchmark_spergel_xvalue_galsim(benchmark, kind):
209-
dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_xvalue(_galsim))
210-
print(f"galsim time: {dt:0.4g} ms", end=" ")
198+
print(f"time: {dt:0.4g} ms", end=" ")
211199

212200

213201
def _run_spergel_bench_kvalue(gsmod):
214202
obj = gsmod.Spergel(nu=-0.6, scale_radius=5)
215-
return obj.drawKImage(nx=50, ny=50, scale=0.2).array
203+
return obj.drawKImage(nx=1024, ny=1204, scale=0.05).array
216204

217205

218206
_run_spergel_bench_kvalue_jit = jax.jit(partial(_run_spergel_bench_kvalue, jgs))
@@ -223,10 +211,4 @@ def test_benchmark_spergel_kvalue(benchmark, kind):
223211
dt = _run_benchmarks(
224212
benchmark, kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready()
225213
)
226-
print(f"jax-galsim time: {dt:0.4g} ms", end=" ")
227-
228-
229-
@pytest.mark.parametrize("kind", ["compile", "run"])
230-
def test_benchmark_spergel_kvalue_galsim(benchmark, kind):
231-
dt = _run_benchmarks(benchmark, kind, lambda: _run_spergel_bench_kvalue(_galsim))
232-
print(f"galsim time: {dt:0.4g} ms", end=" ")
214+
print(f"time: {dt:0.4g} ms", end=" ")

tests/jax/test_spergel_comp_galsim.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1+
import galsim as _galsim
12
import galsim as gs
3+
import jax
24
import numpy as np
35
import pytest
6+
from test_benchmarks import (
7+
_run_spergel_bench_conv,
8+
_run_spergel_bench_conv_jit,
9+
_run_spergel_bench_kvalue,
10+
_run_spergel_bench_kvalue_jit,
11+
_run_spergel_bench_xvalue,
12+
_run_spergel_bench_xvalue_jit,
13+
)
414

515
import jax_galsim as jgs
16+
from jax_galsim.core.testing import time_code_block
617

718

819
@pytest.mark.parametrize(
@@ -78,3 +89,69 @@ def test_spergel_comp_galsim_xvalue(nu, scale_radius, x, y):
7889
s_gs = gs.Spergel(nu=nu, scale_radius=scale_radius)
7990

8091
np.testing.assert_allclose(s_jgs.xValue(x, y), s_gs.xValue(x, y), rtol=1e-5)
92+
93+
94+
def _run_time_test(kind, func):
95+
if kind == "compile":
96+
97+
def _run():
98+
jax.clear_caches()
99+
func()
100+
101+
elif kind == "run":
102+
# run once to compile
103+
func()
104+
105+
def _run():
106+
func()
107+
108+
else:
109+
raise ValueError(f"kind={kind} not recognized")
110+
111+
tot_time = 0
112+
for _ in range(3):
113+
with time_code_block(quiet=True) as tr:
114+
_run()
115+
tot_time += tr.dt
116+
117+
return tot_time / 3
118+
119+
120+
@pytest.mark.parametrize("kind", ["compile", "run"])
121+
def test_spergel_comp_galsim_perf_conv(benchmark, kind):
122+
dt = _run_time_test(kind, lambda: _run_spergel_bench_conv_jit().block_until_ready())
123+
print(f"\njax-galsim time: {dt:0.4g} ms")
124+
125+
dt = _run_time_test(
126+
kind,
127+
lambda: _run_spergel_bench_conv(_galsim),
128+
)
129+
print(f" galsim time: {dt:0.4g} ms")
130+
131+
132+
@pytest.mark.parametrize("kind", ["compile", "run"])
133+
def test_spergel_comp_galsim_perf_kvalue(benchmark, kind):
134+
dt = _run_time_test(
135+
kind, lambda: _run_spergel_bench_kvalue_jit().block_until_ready()
136+
)
137+
print(f"\njax-galsim time: {dt:0.4g} ms")
138+
139+
dt = _run_time_test(
140+
kind,
141+
lambda: _run_spergel_bench_kvalue(_galsim),
142+
)
143+
print(f" galsim time: {dt:0.4g} ms")
144+
145+
146+
@pytest.mark.parametrize("kind", ["compile", "run"])
147+
def test_spergel_comp_galsim_perf_xvalue(benchmark, kind):
148+
dt = _run_time_test(
149+
kind, lambda: _run_spergel_bench_xvalue_jit().block_until_ready()
150+
)
151+
print(f"\njax-galsim time: {dt:0.4g} ms")
152+
153+
dt = _run_time_test(
154+
kind,
155+
lambda: _run_spergel_bench_xvalue(_galsim),
156+
)
157+
print(f" galsim time: {dt:0.4g} ms")

0 commit comments

Comments
 (0)