Skip to content

Commit ae70202

Browse files
committed
feat: add benchmarks for spergel profile
1 parent b769e97 commit ae70202

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/jax/test_benchmarks.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,27 @@ def _run():
159159

160160
dt = _run_benchmarks(benchmark, kind, _run)
161161
print(f"time: {dt:0.4g} ms", end=" ")
162+
163+
164+
@pytest.mark.parametrize("kind", ["compile", "run"])
165+
def test_benchmark_spergel(benchmark, kind):
166+
def _run_jax():
167+
gal = jgs.Spergel(nu=-0.6, scale_radius=4.0)
168+
psf = jgs.Gaussian(fwhm=0.9)
169+
obj = jgs.Convolve([gal, psf])
170+
obj.drawImage(nx=51, ny=51, scale=0.2).array.block_until_ready()
171+
172+
dt = _run_benchmarks(benchmark, kind, _run_jax)
173+
print(f"jax-galsim time: {dt:0.4g} ms", end=" ")
174+
175+
176+
@pytest.mark.parametrize("kind", ["compile", "run"])
177+
def test_benchmark_spergel_galsim(benchmark, kind):
178+
def _run():
179+
gal = _galsim.Spergel(nu=-0.6, scale_radius=4.0)
180+
psf = _galsim.Gaussian(fwhm=0.9)
181+
obj = _galsim.Convolve([gal, psf])
182+
obj.drawImage(nx=51, ny=51, scale=0.2)
183+
184+
dt = _run_benchmarks(benchmark, kind, _run)
185+
print(f"galsim time: {dt:0.4g} ms", end=" ")

0 commit comments

Comments
 (0)