Skip to content

Commit 881a110

Browse files
committed
Merge branch 'sc/lensing_refactor_catalog_n_sources' of github.com:prob-ml/bliss into sc/lensing_refactor_catalog_n_sources
2 parents 289ff33 + e175b0d commit 881a110

File tree

6 files changed

+18
-1506
lines changed

6 files changed

+18
-1506
lines changed

bliss/cached_dataset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,7 @@ def __getitem__(self, index):
177177
self.buffered_file_index = converted_index
178178
with open(self.file_paths[converted_index], "rb") as f:
179179
self.buffered_data = torch.load(f)
180-
try:
181-
output_data = self.buffered_data[converted_sub_index]
182-
except KeyError:
183-
output_data = self.buffered_data
180+
output_data = self.buffered_data[converted_sub_index]
184181
return self.transform(output_data)
185182

186183
def get_chunked_indices(self):

bliss/catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,8 @@ def __init__(self, height: int, width: int, d: Dict[str, Tensor]) -> None:
413413
self.batch_size, self.max_sources, hw = d["plocs"].shape
414414
assert hw == 2
415415
if "n_sources" in d:
416-
assert d["n_sources"].max().int().item() <= self.max_sources
417-
assert d["n_sources"].shape == (self.batch_size,)
416+
assert d.get("n_sources").max().int().item() <= self.max_sources
417+
assert d.get("n_sources").shape == (self.batch_size,)
418418

419419
super().__init__(**d)
420420

bliss/simulator/decoder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def __init__(
1717
survey: Survey,
1818
with_dither: bool = True,
1919
with_noise: bool = True,
20+
faint_flux_threshold: float = None,
21+
faint_folding_threshold: float = None,
2022
) -> None:
2123
"""Construct a decoder for a set of images.
2224
@@ -25,6 +27,8 @@ def __init__(
2527
survey: survey to mimic (psf, background, calibration, etc.)
2628
with_dither: if True, apply random pixel shifts to the images and align them
2729
with_noise: if True, add Poisson noise to the image pixels
30+
faint_flux_threshold: threshold for flux to be considered dim
31+
faint_folding_threshold: folding threshold for dim sources
2832
"""
2933

3034
super().__init__()
@@ -33,6 +37,8 @@ def __init__(
3337
self.survey = survey
3438
self.with_dither = with_dither
3539
self.with_noise = with_noise
40+
self.faint_flux_threshold = faint_flux_threshold
41+
self.faint_folding_threshold = faint_folding_threshold
3642

3743
survey.prepare_data()
3844

@@ -181,6 +187,15 @@ def render_image(self, tile_cat):
181187

182188
# essentially all the runtime of the simulator is incurred by this call
183189
# to drawImage
190+
191+
if source_type:
192+
source_flux = source_params["galaxy_fluxes"][band]
193+
else:
194+
source_flux = source_params["star_fluxes"][band]
195+
196+
if self.faint_flux_threshold and source_flux.item() <= self.faint_flux_threshold:
197+
galsim_obj.folding_threshold = self.faint_folding_threshold
198+
184199
galsim_obj.drawImage(
185200
offset=offset,
186201
method=getattr(self.survey.psf, "psf_draw_method", "auto"),

case_studies/weak_lensing/lensing_encoder.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,6 @@ def _compute_loss(self, batch, logging_name):
7777
batch_size, _, _, _ = batch["images"].shape[0:4]
7878

7979
target_cat = BaseTileCatalog(batch["tile_catalog"])
80-
print("shear", target_cat["shear"].shape)
81-
print("convergence", target_cat["convergence"].shape)
82-
print("ellipticity", target_cat["ellip_lensed"].shape)
83-
print("shear nans: ", torch.isnan(target_cat["shear"]).any())
8480

8581
# multiple image normalizers
8682
input_lst = [inorm.get_input_tensor(batch) for inorm in self.image_normalizers]

0 commit comments

Comments
 (0)