Skip to content

Commit c169703

Browse files
authored
test: add vmap gradient test for Gaussian object
1 parent 5a9acd9 commit c169703

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

tests/jax/test_deriv_gsobject.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,27 @@ def _run(val_):
7070
atol = 1e-5
7171

7272
np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol)
73+
74+
75+
def test_deriv_gsobject_params_vmap():
76+
val = jnp.array([2.0, 3.0])
77+
eps = 1e-5
78+
79+
def _run(val_):
80+
return jnp.max(
81+
jgs.Gaussian(
82+
half_light_radius=val_,
83+
gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64),
84+
)
85+
.drawImage(nx=48, ny=48, scale=0.2)
86+
.array[24, 24]
87+
** 2
88+
)
89+
90+
_vmap_run = jax.vmap(_run)
91+
gfunc = jax.jit(jax.vmap(jax.grad(_run)))
92+
gval = gfunc(val)
93+
94+
gfdiff = (_vmap_run(val + eps) - _vmap_run(val - eps)) / 2.0 / eps
95+
96+
np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6)

0 commit comments

Comments
 (0)