Skip to content

Commit e841597

Browse files
committed
fix: do not use laxy properties or workspaces
1 parent 3c2760b commit e841597

File tree

6 files changed

+74
-114
lines changed

6 files changed

+74
-114
lines changed

jax_galsim/exponential.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
from functools import lru_cache
2+
13
import galsim as _galsim
24
import jax.numpy as jnp
5+
import numpy as np
36
from jax.tree_util import register_pytree_node_class
47

58
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
69
from jax_galsim.core.utils import ensure_hashable, implements
710
from jax_galsim.gsobject import GSObject
811
from jax_galsim.random import UniformDeviate
9-
from jax_galsim.utilities import lazy_property
1012

1113

1214
@implements(_galsim.Exponential)
@@ -147,50 +149,14 @@ def withFlux(self, flux):
147149
scale_radius=self.scale_radius, flux=flux, gsparams=self.gsparams
148150
)
149151

150-
@lazy_property
151-
def _shoot_cdf(self):
152-
# Comments on the math here:
153-
#
154-
# We are looking to draw from a distribution that is r * exp(-r).
155-
# This distribution is the radial PDF of an Exponential profile.
156-
# The factor of r comes from the area element r * dr.
157-
#
158-
# We can compute the CDF of this distribution analytically, but we cannot
159-
# invert the CDF in closed form. Thus we invert it numerically using a table.
160-
#
161-
# One final detail is that we want the inversion to be accurate and are using
162-
# linear interpolation. Thus we use a change of variables r = -ln(1 - u)
163-
# to make the CDF more linear and map it's domain to [0, 1) instead of [0, inf).
164-
#
165-
# Putting this all together, we get
166-
#
167-
# r * exp(-r) dr = -ln(1-u) (1-u) dr/du du
168-
# = -ln(1-u) (1-u) * 1 / (1-u)
169-
# = -ln(1-u)
170-
#
171-
# The new range of integration is u = 0 to u = 1. Thus the CDF is
172-
#
173-
# CDF = -int_0^u ln(1-u') du'
174-
# = u - (u - 1) ln(1 - u)
175-
#
176-
# The final detail is that galsim defines a shoot accuracy and draws photons
177-
# between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to
178-
# its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically.
179-
_rmax = -jnp.log(self.gsparams.shoot_accuracy)
180-
_umax = 1.0 - jnp.exp(-_rmax)
181-
_u_cdf = jnp.linspace(0, _umax, 10000)
182-
_cdf = _u_cdf - (_u_cdf - 1) * jnp.log(1 - _u_cdf)
183-
_cdf /= _cdf[-1]
184-
return _u_cdf, _cdf
185-
186152
@implements(_galsim.Exponential._shoot)
187153
def _shoot(self, photons, rng):
188154
ud = UniformDeviate(rng)
189155

190156
u = ud.generate(
191157
photons.x
192158
) # this does not fill arrays like in galsim so is safe
193-
_u_cdf, _cdf = self._shoot_cdf
159+
_u_cdf, _cdf = _shoot_cdf(self.gsparams.shoot_accuracy)
194160
# this interpolation inverts the CDF
195161
u = jnp.interp(u, _cdf, _u_cdf)
196162
# this converts from u (see above) to r and scales by the actual size of
@@ -203,3 +169,42 @@ def _shoot(self, photons, rng):
203169
photons.x = r * jnp.cos(ang)
204170
photons.y = r * jnp.sin(ang)
205171
photons.flux = self.flux / photons.size()
172+
173+
174+
@lru_cache(maxsize=8)
175+
def _shoot_cdf(shoot_accuracy):
176+
"""This routine produces a CPU-side cache of the CDF that is embedded
177+
into JIT-compiled code as needed."""
178+
# Comments on the math here:
179+
#
180+
# We are looking to draw from a distribution that is r * exp(-r).
181+
# This distribution is the radial PDF of an Exponential profile.
182+
# The factor of r comes from the area element r * dr.
183+
#
184+
# We can compute the CDF of this distribution analytically, but we cannot
185+
# invert the CDF in closed form. Thus we invert it numerically using a table.
186+
#
187+
# One final detail is that we want the inversion to be accurate and are using
188+
# linear interpolation. Thus we use a change of variables r = -ln(1 - u)
189+
# to make the CDF more linear and map it's domain to [0, 1) instead of [0, inf).
190+
#
191+
# Putting this all together, we get
192+
#
193+
# r * exp(-r) dr = -ln(1-u) (1-u) dr/du du
194+
# = -ln(1-u) (1-u) * 1 / (1-u)
195+
# = -ln(1-u)
196+
#
197+
# The new range of integration is u = 0 to u = 1. Thus the CDF is
198+
#
199+
# CDF = -int_0^u ln(1-u') du'
200+
# = u - (u - 1) ln(1 - u)
201+
#
202+
# The final detail is that galsim defines a shoot accuracy and draws photons
203+
# between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to
204+
# its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically.
205+
_rmax = -np.log(shoot_accuracy)
206+
_umax = 1.0 - np.exp(-_rmax)
207+
_u_cdf = np.linspace(0, _umax, 10000)
208+
_cdf = _u_cdf - (_u_cdf - 1) * np.log(1 - _u_cdf)
209+
_cdf /= _cdf[-1]
210+
return _u_cdf, _cdf

jax_galsim/gsobject.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class GSObject:
2929
def __init__(self, *, gsparams=None, **params):
3030
self._params = params # Dictionary containing all traced parameters
3131
self._gsparams = GSParams.check(gsparams) # Non-traced static parameters
32-
self._workspace = {} # used by lazy_property
3332

3433
def __getstate__(self):
3534
d = self.__dict__.copy()

jax_galsim/interpolant.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from jax_galsim.errors import GalSimError
2121
from jax_galsim.gsparams import GSParams
2222
from jax_galsim.random import UniformDeviate
23-
from jax_galsim.utilities import lazy_property
2423

2524

2625
@implements(_galsim.interpolant.Interpolant)
@@ -225,7 +224,7 @@ def urange(self):
225224
% self.__class__.__name__
226225
)
227226

228-
@lazy_property
227+
# TODO: Work out CPU-side caching and pre-generation for this
229228
def _shoot_cdf(self):
230229
x = jnp.linspace(-self.xrange, self.xrange, 10000)
231230
px = jnp.abs(self._xval_noraise(jnp.abs(x)))
@@ -1389,7 +1388,7 @@ def _du(self):
13891388
/ self._n
13901389
)
13911390

1392-
@lazy_property
1391+
@property
13931392
def _umax(self):
13941393
return _find_umax_lanczos(
13951394
self._du,

jax_galsim/interpolatedimage.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from jax_galsim.position import PositionD
3232
from jax_galsim.random import UniformDeviate
3333
from jax_galsim.transform import Transformation
34-
from jax_galsim.utilities import convert_interpolant, lazy_property
34+
from jax_galsim.utilities import convert_interpolant
3535
from jax_galsim.wcs import BaseWCS, PixelScale
3636

3737
# These keys are removed from the public API of
@@ -548,7 +548,7 @@ def k_interpolant(self):
548548
"""The Fourier-space `Interpolant` for this profile."""
549549
return self._k_interpolant
550550

551-
@lazy_property
551+
@property
552552
def image(self):
553553
"""The underlying `Image` being interpolated."""
554554
return self._xim[self._image.bounds]
@@ -557,19 +557,19 @@ def image(self):
557557
def _flux(self):
558558
return self._image_flux
559559

560-
@lazy_property
560+
@property
561561
def _centroid(self):
562562
x, y = self._pad_image.get_pixel_centers()
563563
tot = jnp.sum(self._pad_image.array)
564564
xpos = jnp.sum(x * self._pad_image.array) / tot
565565
ypos = jnp.sum(y * self._pad_image.array) / tot
566566
return PositionD(xpos, ypos)
567567

568-
@lazy_property
568+
@property
569569
def _max_sb(self):
570570
return jnp.max(jnp.abs(self._pad_image.array))
571571

572-
@lazy_property
572+
@property
573573
def _flux_ratio(self):
574574
if self._jax_children[1]["flux"] is None:
575575
flux = self._image_flux
@@ -585,11 +585,11 @@ def _flux_ratio(self):
585585
# this class
586586
return flux / self._image_flux
587587

588-
@lazy_property
588+
@property
589589
def _image_flux(self):
590590
return jnp.sum(self._image.array, dtype=float)
591591

592-
@lazy_property
592+
@property
593593
def _offset(self):
594594
# Figure out the offset to apply based on the original image (not the padded one).
595595
# We will apply this below in _sbp.
@@ -598,7 +598,7 @@ def _offset(self):
598598
self._image.bounds, offset, None, self._jax_aux_data["use_true_center"]
599599
)
600600

601-
@lazy_property
601+
@property
602602
def _image(self):
603603
# Store the image as an attribute and make sure we don't change the original image
604604
# in anything we do here. (e.g. set scale, etc.)
@@ -616,7 +616,7 @@ def _image(self):
616616

617617
return image
618618

619-
@lazy_property
619+
@property
620620
def _wcs(self):
621621
im_cen = (
622622
self._jax_children[0].true_center
@@ -634,15 +634,15 @@ def _wcs(self):
634634

635635
return wcs.local(image_pos=im_cen)
636636

637-
@lazy_property
637+
@property
638638
def _jac_arr(self):
639639
image = self._jax_children[0]
640640
im_cen = (
641641
image.true_center if self._jax_aux_data["use_true_center"] else image.center
642642
)
643643
return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel()
644644

645-
@lazy_property
645+
@property
646646
def _xim(self):
647647
pad_factor = self._jax_aux_data["pad_factor"]
648648

@@ -669,27 +669,27 @@ def _xim(self):
669669

670670
return xim
671671

672-
@lazy_property
672+
@property
673673
def _pad_image(self):
674674
# These next two allow for easy pickling/repring. We don't need to serialize all the
675675
# zeros around the edge. But we do need to keep any non-zero padding as a pad_image.
676676
xim = self._xim
677677
nz_bounds = self._image.bounds
678678
return xim[nz_bounds]
679679

680-
@lazy_property
680+
@property
681681
def _kim(self):
682682
return self._xim.calculate_fft()
683683

684-
@lazy_property
684+
@property
685685
def _maxk(self):
686686
if self._jax_aux_data["_force_maxk"]:
687687
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
688688
return self._jax_aux_data["_force_maxk"] * minor
689689
else:
690690
return self._getMaxK(self._jax_aux_data["calculate_maxk"])
691691

692-
@lazy_property
692+
@property
693693
def _stepk(self):
694694
if self._jax_aux_data["_force_stepk"]:
695695
_, minor = compute_major_minor_from_jacobian(self._jac_arr.reshape((2, 2)))
@@ -837,7 +837,7 @@ def _drawKImage(self, image, jac=None):
837837
# Return an image
838838
return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)
839839

840-
@lazy_property
840+
@property
841841
def _pos_neg_fluxes(self):
842842
# record pos and neg fluxes now too
843843
pflux = jnp.sum(jnp.where(self._pad_image.array > 0, self._pad_image.array, 0))

jax_galsim/spergel.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements
1010
from jax_galsim.gsobject import GSObject
1111
from jax_galsim.random import UniformDeviate
12-
from jax_galsim.utilities import lazy_property
1312

1413

1514
@jax.jit
@@ -321,34 +320,34 @@ def scale_radius(self):
321320
def _r0(self):
322321
return self.scale_radius
323322

324-
@lazy_property
323+
@property
325324
def _inv_r0(self):
326325
return 1.0 / self._r0
327326

328-
@lazy_property
327+
@property
329328
def _r0_sq(self):
330329
return self._r0 * self._r0
331330

332-
@lazy_property
331+
@property
333332
def _inv_r0_sq(self):
334333
return self._inv_r0 * self._inv_r0
335334

336-
@lazy_property
335+
@property
337336
@implements(_galsim.spergel.Spergel.half_light_radius)
338337
def half_light_radius(self):
339338
return self._r0 * _spergel_hlr_pade(self.nu)
340339

341-
@lazy_property
340+
@property
342341
def _shootxnorm(self):
343342
"""Normalization for photon shooting"""
344343
return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu))
345344

346-
@lazy_property
345+
@property
347346
def _xnorm(self):
348347
"""Normalization of xValue"""
349348
return self._shootxnorm * self.flux * self._inv_r0_sq
350349

351-
@lazy_property
350+
@property
352351
def _xnorm0(self):
353352
"""return z^nu K_nu(z) for z=0"""
354353
return jax.lax.select(
@@ -392,21 +391,21 @@ def __str__(self):
392391
s += ")"
393392
return s
394393

395-
@lazy_property
394+
@property
396395
def _maxk(self):
397396
"""(1+ (k r0)^2)^(-1-nu) = maxk_threshold"""
398397
res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0
399398
return jnp.sqrt(res) / self._r0
400399

401-
@lazy_property
400+
@property
402401
def _stepk(self):
403402
R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu)
404403
R *= self._r0
405404
# Go to at least 5*hlr
406405
R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius)
407406
return jnp.pi / R
408407

409-
@lazy_property
408+
@property
410409
def _max_sb(self):
411410
# from SBSpergelImpl.h
412411
return jnp.abs(self._xnorm) * self._xnorm0
@@ -439,7 +438,7 @@ def withFlux(self, flux):
439438
gsparams=self.gsparams,
440439
)
441440

442-
@lazy_property
441+
@property
443442
def _shoot_pos_cdf(self):
444443
zmax = calculateFluxRadius(
445444
1.0 - self.gsparams.shoot_accuracy, self.nu, zmax=30.0
@@ -459,7 +458,7 @@ def _shoot_pos(self, u):
459458
r = z * self._r0
460459
return r
461460

462-
@lazy_property
461+
@property
463462
def _shoot_neg_cdf(self):
464463
# comment:
465464
# In the Galsim code the profile below rmin is linearized such that

0 commit comments

Comments
 (0)