Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
if len(self.obj_list) > 1:
for obj in self.obj_list[1:]:
image *= obj._drawKImage(image, jac)
image._array = image._array.at[...].multiply(
obj._drawKImage(image, jac)._array
)
return image

def tree_flatten(self):
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,7 +1407,7 @@ def _draw_phot_while_loop_shoot(

im1 = ImageD(bounds=image.bounds)
added_flux += sensor.accumulate(photons, im1, orig_center)
image += im1
image._array = image._array.at[...].add(im1)

return _DrawPhotReturnTuple(
photons, rng, added_flux, image, photon_ops, sensor, resume
Expand Down
18 changes: 9 additions & 9 deletions jax_galsim/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def calculate_fft(self):
jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0
)

out *= dx * dx
out._array = out._array.at[...].multiply(dx * dx)
out.setOrigin(0, -No2)
return out

Expand Down Expand Up @@ -769,7 +769,7 @@ def calculate_inverse_fft(self):
)
# Now cut off the bit we don't need.
out = out_extra.subImage(BoundsI(-No2, No2 - 1, -No2, No2 - 1))
out *= (dk * No2 / jnp.pi) ** 2
out._array = out._array.at[:].multiply((dk * No2 / jnp.pi) ** 2)
out.setCenter(0, 0)
return out

Expand Down Expand Up @@ -1269,7 +1269,7 @@ def Image_iadd(self, other):
a = other
dt = type(a)
if dt == self.array.dtype:
self._array = self.array + a
self._array = self.array.at[...].add(a)
else:
self._array = (self.array + a).astype(self.array.dtype)
return self
Expand Down Expand Up @@ -1297,7 +1297,7 @@ def Image_isub(self, other):
a = other
dt = type(a)
if dt == self.array.dtype:
self._array = self.array - a
self._array = self.array.at[...].subtract(a)
else:
self._array = (self.array - a).astype(self.array.dtype)
return self
Expand All @@ -1321,7 +1321,7 @@ def Image_imul(self, other):
a = other
dt = type(a)
if dt == self.array.dtype:
self._array = self.array * a
self._array = self.array.at[...].multiply(a)
else:
self._array = (self.array * a).astype(self.array.dtype)
return self
Expand Down Expand Up @@ -1351,7 +1351,7 @@ def Image_idiv(self, other):
if dt == self.array.dtype and not self.isinteger:
# if dtype is an integer type, then numpy doesn't allow true division /= to assign
# back to an integer array. So for integers (or mixed types), don't use /=.
self._array = self.array / a
self._array = self.array.at[...].divide(a)
else:
self._array = (self.array / a).astype(self.array.dtype)
return self
Expand All @@ -1363,12 +1363,12 @@ def Image_floordiv(self, other):
a = other.array
except AttributeError:
a = other
return Image(self.array // a, bounds=self.bounds, wcs=self.wcs)
return Image(self._array // a, bounds=self.bounds, wcs=self.wcs)
Comment thread
beckermr marked this conversation as resolved.
Outdated


def Image_rfloordiv(self, other):
check_image_consistency(self, other, integer=True)
return Image(other // self.array, bounds=self.bounds, wcs=self.wcs)
return Image(other // self._array, bounds=self.bounds, wcs=self.wcs)
Comment thread
beckermr marked this conversation as resolved.
Outdated


def Image_ifloordiv(self, other):
Expand Down Expand Up @@ -1422,7 +1422,7 @@ def Image_pow(self, other):
def Image_ipow(self, other):
if not isinstance(other, int) and not isinstance(other, float):
raise TypeError("Can only raise an image to a float or int power!")
self._array = self.array**other
self._array = self.array.at[...].power(other)
return self


Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def addNoiseSNR(self, noise, snr, preserve_flux=False):
else:
sn_meas = jnp.sqrt(sumsq / noise_var)
flux = snr / sn_meas
self *= flux
self._array = self._array.at[...].multiply(flux)
Comment thread
beckermr marked this conversation as resolved.
Outdated
self.addNoise(noise)
return noise_var

Expand Down
8 changes: 6 additions & 2 deletions jax_galsim/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,18 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling)
if len(self.obj_list) > 1:
for obj in self.obj_list[1:]:
image += obj._drawReal(image, jac, offset, flux_scaling)
image._array = image._array.at[...].add(
obj._drawReal(image, jac, offset, flux_scaling)._array
)
return image

def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
if len(self.obj_list) > 1:
for obj in self.obj_list[1:]:
image += obj._drawKImage(image, jac)
image._array = image._array.at[...].add(
obj._drawKImage(image, jac)._array
)
return image

@property
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_interpolatedimage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def _compute_fft_with_numpy_jax_galsim(im):

out = Image(BoundsI(0, No2, -No2, No2 - 1), dtype=np.complex128, scale=dk)
out._array = np.fft.fftshift(np.fft.rfft2(np.fft.fftshift(ximage.array)), axes=0)
out *= dx * dx
out._array *= dx * dx
Comment thread
beckermr marked this conversation as resolved.
Outdated
out.setOrigin(0, -No2)
return out

Expand Down
Loading