1+ from functools import lru_cache
2+
13import galsim as _galsim
24import jax .numpy as jnp
5+ import numpy as np
36from jax .tree_util import register_pytree_node_class
47
58from jax_galsim .core .draw import draw_by_kValue , draw_by_xValue
69from jax_galsim .core .utils import ensure_hashable , implements
710from jax_galsim .gsobject import GSObject
811from jax_galsim .random import UniformDeviate
9- from jax_galsim .utilities import lazy_property
1012
1113
1214@implements (_galsim .Exponential )
@@ -147,50 +149,14 @@ def withFlux(self, flux):
147149 scale_radius = self .scale_radius , flux = flux , gsparams = self .gsparams
148150 )
149151
150- @lazy_property
151- def _shoot_cdf (self ):
152- # Comments on the math here:
153- #
154- # We are looking to draw from a distribution that is r * exp(-r).
155- # This distribution is the radial PDF of an Exponential profile.
156- # The factor of r comes from the area element r * dr.
157- #
158- # We can compute the CDF of this distribution analytically, but we cannot
159- # invert the CDF in closed form. Thus we invert it numerically using a table.
160- #
161- # One final detail is that we want the inversion to be accurate and are using
162- # linear interpolation. Thus we use a change of variables r = -ln(1 - u)
163- # to make the CDF more linear and map it's domain to [0, 1) instead of [0, inf).
164- #
165- # Putting this all together, we get
166- #
167- # r * exp(-r) dr = -ln(1-u) (1-u) dr/du du
168- # = -ln(1-u) (1-u) * 1 / (1-u)
169- # = -ln(1-u)
170- #
171- # The new range of integration is u = 0 to u = 1. Thus the CDF is
172- #
173- # CDF = -int_0^u ln(1-u') du'
174- # = u - (u - 1) ln(1 - u)
175- #
176- # The final detail is that galsim defines a shoot accuracy and draws photons
177- # between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to
178- # its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically.
179- _rmax = - jnp .log (self .gsparams .shoot_accuracy )
180- _umax = 1.0 - jnp .exp (- _rmax )
181- _u_cdf = jnp .linspace (0 , _umax , 10000 )
182- _cdf = _u_cdf - (_u_cdf - 1 ) * jnp .log (1 - _u_cdf )
183- _cdf /= _cdf [- 1 ]
184- return _u_cdf , _cdf
185-
186152 @implements (_galsim .Exponential ._shoot )
187153 def _shoot (self , photons , rng ):
188154 ud = UniformDeviate (rng )
189155
190156 u = ud .generate (
191157 photons .x
192158 ) # this does not fill arrays like in galsim so is safe
193- _u_cdf , _cdf = self ._shoot_cdf
159+ _u_cdf , _cdf = _shoot_cdf ( self .gsparams . shoot_accuracy )
194160 # this interpolation inverts the CDF
195161 u = jnp .interp (u , _cdf , _u_cdf )
196162 # this converts from u (see above) to r and scales by the actual size of
@@ -203,3 +169,42 @@ def _shoot(self, photons, rng):
203169 photons .x = r * jnp .cos (ang )
204170 photons .y = r * jnp .sin (ang )
205171 photons .flux = self .flux / photons .size ()
172+
173+
174+ @lru_cache (maxsize = 8 )
175+ def _shoot_cdf (shoot_accuracy ):
176+ """This routine produces a CPU-side cache of the CDF that is embedded
177+ into JIT-compiled code as needed."""
178+ # Comments on the math here:
179+ #
180+ # We are looking to draw from a distribution that is r * exp(-r).
181+ # This distribution is the radial PDF of an Exponential profile.
182+ # The factor of r comes from the area element r * dr.
183+ #
184+ # We can compute the CDF of this distribution analytically, but we cannot
185+ # invert the CDF in closed form. Thus we invert it numerically using a table.
186+ #
187+ # One final detail is that we want the inversion to be accurate and are using
188+ # linear interpolation. Thus we use a change of variables r = -ln(1 - u)
189+ # to make the CDF more linear and map it's domain to [0, 1) instead of [0, inf).
190+ #
191+ # Putting this all together, we get
192+ #
193+ # r * exp(-r) dr = -ln(1-u) (1-u) dr/du du
194+ # = -ln(1-u) (1-u) * 1 / (1-u)
195+ # = -ln(1-u)
196+ #
197+ # The new range of integration is u = 0 to u = 1. Thus the CDF is
198+ #
199+ # CDF = -int_0^u ln(1-u') du'
200+ # = u - (u - 1) ln(1 - u)
201+ #
202+ # The final detail is that galsim defines a shoot accuracy and draws photons
203+ # between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to
204+ # its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically.
205+ _rmax = - np .log (shoot_accuracy )
206+ _umax = 1.0 - np .exp (- _rmax )
207+ _u_cdf = np .linspace (0 , _umax , 10000 )
208+ _cdf = _u_cdf - (_u_cdf - 1 ) * np .log (1 - _u_cdf )
209+ _cdf /= _cdf [- 1 ]
210+ return _u_cdf , _cdf
0 commit comments