Skip to content

Commit

Permalink
Regular update (#1047)
Browse files Browse the repository at this point in the history
* update notebooks

* update notebooks

* update config

* update exp_record

* update MultiDetectEncoder

* update MultiDetectEncoder

* update notebooks; update SourceTypeAccuracy to report w.r.t magnitudes

* update notebooks

* add ellipticity estimation

* fix tests

* fix the warning in sampler

* add a new notebook to show the mse of ellipticity estimation

* change MSE to residual for ellipticity measure

* modify units in catalog; remove background from dc2

* replace on_fluxes, magnitudes, magnitudes_njy with on_nmgy, on_mag and on_njy

* fix notebook after merge

* change the gating logic of VariationalFactor

* delete old notebooks

* address pr comments

---------

Co-authored-by: Yicun Duan <[email protected]>
  • Loading branch information
YicunDuanUMich and Yicun Duan authored Jul 24, 2024
1 parent d84e937 commit 2cefde4
Show file tree
Hide file tree
Showing 43 changed files with 3,998 additions and 3,252 deletions.
2 changes: 1 addition & 1 deletion bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__(dataset)
super().__init__()
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
self.dataset = dataset

Expand Down
78 changes: 54 additions & 24 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,22 @@ def convert_nmgy_to_mag(nmgy):


def convert_nmgy_to_njymag(nmgy):
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2."""
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
the "Flux units: maggies and nanomaggies" part of
https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
When we change the standard source to AB sources, we need to do the conversion
described in "2.10 AB magnitudes" at
https://pstn-001.lsst.io/fluxunits.pdf
Args:
nmgy: the fluxes in nanomaggies
Returns:
Tensor indicating fluxes in AB magnitude
"""

return 22.5 - 2.5 * torch.log10(nmgy / 3631)


Expand Down Expand Up @@ -148,13 +163,19 @@ def galaxy_bools(self) -> Tensor:
is_galaxy = self["source_type"] == SourceType.GALAXY
return is_galaxy * self.is_on_mask.unsqueeze(-1)

@property
def on_fluxes(self):
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
def on_fluxes(self, unit: str):
match unit:
case "nmgy":
return self.on_nmgy
case "mag":
return self.on_mag
case "njymag":
return self.on_njymag
case _:
raise NotImplementedError()

Returns:
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
"""
@property
def on_nmgy(self):
# TODO: a tile catalog should store fluxes rather than star_fluxes and galaxy_fluxes
# because that's all that's needed to render the source
if "galaxy_fluxes" not in self:
Expand All @@ -164,13 +185,12 @@ def on_fluxes(self):
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))

@property
def magnitudes(self):
# TODO: we shouldn't assume fluxes are stored in nanomaggies because they aren't for DC2
return convert_nmgy_to_mag(self.on_fluxes)
def on_mag(self) -> Tensor:
return convert_nmgy_to_mag(self.on_nmgy)

@property
def magnitudes_njy(self):
return convert_nmgy_to_njymag(self.on_fluxes)
def on_njymag(self) -> Tensor:
return convert_nmgy_to_njymag(self.on_nmgy)

def to_full_catalog(self, tile_slen):
"""Converts image parameters in tiles to parameters of full image.
Expand Down Expand Up @@ -266,8 +286,8 @@ def get_indices_of_on_sources(self) -> Tuple[Tensor, Tensor]:

def _sort_sources_by_flux(self, band=2):
# sort by fluxes of "on" sources to get brightest source per tile
on_fluxes = self.on_fluxes[..., band] # shape n x nth x ntw x d
top_indexes = on_fluxes.argsort(dim=3, descending=True)
on_nmgy = self.on_nmgy[..., band] # shape n x nth x ntw x d
top_indexes = on_nmgy.argsort(dim=3, descending=True)

d = {"n_sources": self["n_sources"]}
for key, val in self.items():
Expand Down Expand Up @@ -323,8 +343,8 @@ def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
sorted_self = self._sort_sources_by_flux(band=band)

# get fluxes of "on" sources to mask by
on_fluxes = sorted_self.on_fluxes[..., band]
flux_mask = (on_fluxes > min_flux) & (on_fluxes < max_flux)
on_nmgy = sorted_self.on_nmgy[..., band]
flux_mask = (on_nmgy > min_flux) & (on_nmgy < max_flux)

d = {}
for key, val in sorted_self.items():
Expand Down Expand Up @@ -458,13 +478,19 @@ def galaxy_bools(self) -> Tensor:
assert is_galaxy.size(2) == 1
return is_galaxy * self.is_on_mask.unsqueeze(2)

@property
def on_fluxes(self) -> Tensor:
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
def on_fluxes(self, unit: str):
match unit:
case "nmgy":
return self.on_nmgy
case "mag":
return self.on_mag
case "njymag":
return self.on_njymag
case _:
raise NotImplementedError()

Returns:
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
"""
@property
def on_nmgy(self) -> Tensor:
# ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes
if "galaxy_fluxes" not in self:
fluxes = self["star_fluxes"]
Expand All @@ -473,8 +499,12 @@ def on_fluxes(self) -> Tensor:
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))

@property
def magnitudes(self):
return convert_nmgy_to_mag(self.on_fluxes)
def on_mag(self) -> Tensor:
return convert_nmgy_to_mag(self.on_nmgy)

@property
def on_njymag(self) -> Tensor:
return convert_nmgy_to_njymag(self.on_nmgy)

def one_source(self, b: int, s: int):
"""Return a dict containing all parameter for one specified light source."""
Expand Down
47 changes: 32 additions & 15 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,64 +80,75 @@ variational_factors:
name: locs
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 d -> b ht wt d"
nll_gating: n_sources
nll_gating:
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.BernoulliFactor
name: source_type
sample_rearrange: "b ht wt -> b ht wt 1 1"
nll_rearrange: "b ht wt 1 1 -> b ht wt"
nll_gating: n_sources
nll_gating:
_target_: bliss.encoder.variational_dist.SourcesGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: star_fluxes
dim: 5
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 d -> b ht wt d"
nll_gating: is_star
nll_gating:
_target_: bliss.encoder.variational_dist.StarGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: galaxy_fluxes
dim: 5
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 d -> b ht wt d"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
name: galaxy_disk_frac
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
name: galaxy_beta_radians
high: 3.1415926
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
name: galaxy_disk_q
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: galaxy_a_d
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
name: galaxy_bulge_q
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating
- _target_: bliss.encoder.variational_dist.LogNormalFactor
name: galaxy_a_b
sample_rearrange: "b ht wt d -> b ht wt 1 d"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating

metrics:
detection_performance:
_target_: bliss.encoder.metrics.DetectionPerformance
mag_bin_cutoffs: [19, 19.4, 19.8, 20.2, 20.6, 21, 21.4, 21.8]
mag_band: 2
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
source_type_accuracy:
_target_: bliss.encoder.metrics.SourceTypeAccuracy
flux_bin_cutoffs: [200, 400, 600, 800, 1000]
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
flux_error:
_target_: bliss.encoder.metrics.FluxError
survey_bands: ${encoder.survey_bands}
Expand Down Expand Up @@ -230,7 +241,7 @@ surveys:
n_image_split: 50
tile_slen: 4
max_sources_per_tile: 5
min_flux: 0.0
catalog_min_r_flux: 50
prepare_data_processes_num: 4
data_in_one_cached_file: 1250
splits: 0:80/80:90/90:100
Expand All @@ -242,7 +253,13 @@ surveys:
- _target_: bliss.data_augmentation.RandomShiftTransform
tile_slen: ${surveys.dc2.tile_slen}
max_sources_per_tile: ${surveys.dc2.tile_slen}
nontrain_transforms: []
- _target_: bliss.cached_dataset.FluxFilterTransform
reference_band: 2 # r-band
min_flux: 100
nontrain_transforms:
- _target_: bliss.cached_dataset.FluxFilterTransform
reference_band: 2 # r-band
min_flux: 100


#######################################################################
Expand Down
4 changes: 2 additions & 2 deletions bliss/encoder/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
self,
survey_bands: list,
tile_slen: int,
image_normalizers: list,
image_normalizers: dict,
var_dist: VariationalDist,
matcher: CatalogMatcher,
sample_image_renders: MetricCollection,
Expand Down Expand Up @@ -105,7 +105,7 @@ def make_context(self, history_cat, history_mask, detection2=False):
)
else:
centered_locs = history_cat["locs"][..., 0, :] - 0.5
log_fluxes = (history_cat.on_fluxes.squeeze(3).sum(-1) + 1).log()
log_fluxes = (history_cat.on_nmgy.squeeze(3).sum(-1) + 1).log()
history_encoding_lst = [
history_cat["n_sources"].float(), # detection history
log_fluxes * history_cat["n_sources"], # flux history
Expand Down
Loading

0 comments on commit 2cefde4

Please sign in to comment.