Skip to content

Commit

Permalink
Merge branch 'sc/lensing_refactor_catalog_n_sources' of github.com:pr…
Browse files Browse the repository at this point in the history
…ob-ml/bliss into sc/lensing_refactor_catalog_n_sources
  • Loading branch information
shreyasc30 committed Aug 19, 2024
2 parents 289ff33 + e175b0d commit 881a110
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 1,506 deletions.
5 changes: 1 addition & 4 deletions bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ def __getitem__(self, index):
self.buffered_file_index = converted_index
with open(self.file_paths[converted_index], "rb") as f:
self.buffered_data = torch.load(f)
try:
output_data = self.buffered_data[converted_sub_index]
except KeyError:
output_data = self.buffered_data
output_data = self.buffered_data[converted_sub_index]
return self.transform(output_data)

def get_chunked_indices(self):
Expand Down
4 changes: 2 additions & 2 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ def __init__(self, height: int, width: int, d: Dict[str, Tensor]) -> None:
self.batch_size, self.max_sources, hw = d["plocs"].shape
assert hw == 2
if "n_sources" in d:
assert d["n_sources"].max().int().item() <= self.max_sources
assert d["n_sources"].shape == (self.batch_size,)
assert d.get("n_sources").max().int().item() <= self.max_sources
assert d.get("n_sources").shape == (self.batch_size,)

super().__init__(**d)

Expand Down
15 changes: 15 additions & 0 deletions bliss/simulator/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(
survey: Survey,
with_dither: bool = True,
with_noise: bool = True,
faint_flux_threshold: float = None,
faint_folding_threshold: float = None,
) -> None:
"""Construct a decoder for a set of images.
Expand All @@ -25,6 +27,8 @@ def __init__(
survey: survey to mimic (psf, background, calibration, etc.)
with_dither: if True, apply random pixel shifts to the images and align them
with_noise: if True, add Poisson noise to the image pixels
faint_flux_threshold: threshold for flux to be considered dim
faint_folding_threshold: folding threshold for dim sources
"""

super().__init__()
Expand All @@ -33,6 +37,8 @@ def __init__(
self.survey = survey
self.with_dither = with_dither
self.with_noise = with_noise
self.faint_flux_threshold = faint_flux_threshold
self.faint_folding_threshold = faint_folding_threshold

survey.prepare_data()

Expand Down Expand Up @@ -181,6 +187,15 @@ def render_image(self, tile_cat):

# essentially all the runtime of the simulator is incurred by this call
# to drawImage

if source_type:
source_flux = source_params["galaxy_fluxes"][band]
else:
source_flux = source_params["star_fluxes"][band]

if self.faint_flux_threshold and source_flux.item() <= self.faint_flux_threshold:
galsim_obj.folding_threshold = self.faint_folding_threshold

galsim_obj.drawImage(
offset=offset,
method=getattr(self.survey.psf, "psf_draw_method", "auto"),
Expand Down
4 changes: 0 additions & 4 deletions case_studies/weak_lensing/lensing_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ def _compute_loss(self, batch, logging_name):
batch_size, _, _, _ = batch["images"].shape[0:4]

target_cat = BaseTileCatalog(batch["tile_catalog"])
print("shear", target_cat["shear"].shape)
print("convergence", target_cat["convergence"].shape)
print("ellipticity", target_cat["ellip_lensed"].shape)
print("shear nans: ", torch.isnan(target_cat["shear"]).any())

# multiple image normalizers
input_lst = [inorm.get_input_tensor(batch) for inorm in self.image_normalizers]
Expand Down
Loading

0 comments on commit 881a110

Please sign in to comment.