Skip to content

Commit 4922f21

Browse files
committed
test: more tests for bounds set
1 parent aa8d587 commit 4922f21

1 file changed

Lines changed: 110 additions & 1 deletion

File tree

tests/jax/test_bounds_jax.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import jax
24
import jax.numpy as jnp
35
import numpy as np
@@ -217,7 +219,7 @@ def test_bounds_jax_vmap_plus_raises_int():
217219
_plus_bounds_pos_far_away_float(bnds)
218220

219221

220-
def test_bounds_jax_int_set():
222+
def test_bounds_jax_int_set_static():
221223
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
222224

223225
bnds.xmin = 11.0
@@ -247,3 +249,110 @@ def test_bounds_jax_int_set():
247249
bnds.deltay = jnp.array(13, dtype=float)
248250
assert isinstance(bnds.deltay, int)
249251
assert bnds.deltay == 13
252+
253+
254+
def test_bounds_jax_int_set_dynamic():
255+
bnds = jax_galsim.BoundsI(
256+
xmin=jnp.array(1), ymin=jnp.array(2), deltax=10, deltay=11
257+
)
258+
259+
bnds.xmin = 11.0
260+
assert isinstance(bnds.xmin, jnp.ndarray)
261+
assert bnds.xmin == 11
262+
bnds.xmin = jnp.array(12, dtype=float)
263+
assert isinstance(bnds.xmin, jnp.ndarray)
264+
assert bnds.xmin == 12
265+
266+
bnds.ymin = 12.0
267+
assert isinstance(bnds.ymin, jnp.ndarray)
268+
assert bnds.ymin == 12
269+
bnds.ymin = jnp.array(13, dtype=float)
270+
assert isinstance(bnds.ymin, jnp.ndarray)
271+
assert bnds.ymin == 13
272+
273+
bnds.deltax = 11.0
274+
assert isinstance(bnds.deltax, int)
275+
assert bnds.deltax == 11
276+
bnds.deltax = jnp.array(12, dtype=float)
277+
assert isinstance(bnds.deltax, int)
278+
assert bnds.deltax == 12
279+
280+
bnds.deltay = 12.0
281+
assert isinstance(bnds.deltay, int)
282+
assert bnds.deltay == 12
283+
bnds.deltay = jnp.array(13, dtype=float)
284+
assert isinstance(bnds.deltay, int)
285+
assert bnds.deltay == 13
286+
287+
288+
def test_bounds_jax_int_set_jit_raises():
289+
@jax.jit
290+
def _make_bnds_bad_xmin(xmin):
291+
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
292+
bnds.xmin = xmin
293+
return bnds
294+
295+
@jax.jit
296+
def _make_bnds_bad_ymin(ymin):
297+
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
298+
bnds.ymin = ymin
299+
return bnds
300+
301+
with pytest.raises(Exception):
302+
_make_bnds_bad_xmin(2)
303+
304+
with pytest.raises(Exception):
305+
_make_bnds_bad_ymin(2)
306+
307+
@jax.jit
308+
def _make_bnds_bad_deltax(deltax):
309+
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
310+
bnds.deltax = deltax
311+
return bnds
312+
313+
@jax.jit
314+
def _make_bnds_bad_deltay(deltay):
315+
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
316+
bnds.deltay = deltay
317+
return bnds
318+
319+
with pytest.raises(Exception):
320+
_make_bnds_bad_deltay(2)
321+
322+
with pytest.raises(Exception):
323+
_make_bnds_bad_deltay(2)
324+
325+
326+
def test_bounds_jax_int_set_jit():
327+
@jax.jit
328+
def _make_bnds_set_xmin(xmin):
329+
bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11)
330+
bnds.xmin = xmin
331+
return bnds
332+
333+
@jax.jit
334+
def _make_bnds_set_ymin(ymin):
335+
bnds = jax_galsim.BoundsI(xmin=jnp.array(1), ymin=1, deltax=10, deltay=11)
336+
bnds.ymin = ymin
337+
return bnds
338+
339+
bnds = _make_bnds_set_xmin(2)
340+
assert isinstance(bnds.ymin, jnp.ndarray)
341+
assert bnds.xmin == 2
342+
assert isinstance(bnds.xmin, jnp.ndarray)
343+
344+
bnds = _make_bnds_set_ymin(2)
345+
assert isinstance(bnds.ymin, jnp.ndarray)
346+
assert bnds.ymin == 2
347+
assert isinstance(bnds.ymin, jnp.ndarray)
348+
349+
@partial(jax.jit, static_argnames=("xmin",))
350+
def _make_bnds_set_xmin_static(xmin):
351+
bnds = jax_galsim.BoundsI(xmin=1, ymin=1, deltax=10, deltay=11)
352+
bnds.xmin = xmin
353+
return bnds
354+
355+
bnds = _make_bnds_set_xmin_static(2)
356+
assert isinstance(bnds.xmin, int)
357+
assert isinstance(bnds.ymin, int)
358+
assert bnds.xmin == 2

0 commit comments

Comments
 (0)