Skip to content

Commit

Permalink
overhaul metrics (#951)
Browse files Browse the repository at this point in the history
* moved dist_param_groups to VariationalGrid

* now Encoder takes a VariationalGridMaker instance

* rename 'layers' to 'detections'

* forgot to add variational_grid.py

* renaming. VariationalGridMaker -> VariationalDistSpec; pred -> factor

* require magnitudes to match too in CatalogMetrics

* remove tile matching; improve metrics

* tests passing with some hacky stuff

* update does it all

* only compute galsim param error if there are galaxies

* fixed tests

* remove sklearn dependence in metrics

* don't couple metrics and vardist just to access GALSIM_NAMES

* manage metrics manually

* remove gal_fp and star_fp

* f1 -> detection_f1

* computing recall and precision per magnitude bin

* MetricCollection

* added plotting routine to show detection performance binned by magnitude

* exclude last magnitude bin
  • Loading branch information
jeff-regier authored Dec 17, 2023
1 parent 20ba5fc commit b14177c
Show file tree
Hide file tree
Showing 11 changed files with 584 additions and 576 deletions.
66 changes: 48 additions & 18 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
from torch import Tensor


def convert_mag_to_nmgy(mag):
return 10 ** ((22.5 - mag) / 2.5)


def convert_nmgy_to_mag(nmgy):
return 22.5 - 2.5 * torch.log10(nmgy)


class SourceType(IntEnum):
STAR = 0
GALAXY = 1
Expand Down Expand Up @@ -112,6 +120,20 @@ def galaxy_bools(self) -> Tensor:
is_galaxy = self["source_type"] == SourceType.GALAXY
return is_galaxy * self.is_on_mask.unsqueeze(-1)

@property
def on_fluxes(self):
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
Returns:
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
"""
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))

@property
def magnitudes(self):
return convert_nmgy_to_mag(self.on_fluxes)

@property
def device(self):
return self.locs.device
Expand Down Expand Up @@ -242,18 +264,9 @@ def gather_param_at_tiles(self, param_name: str, indices: Tensor) -> Tensor:
idx_to_gather = repeat(indices, "... -> ... k", k=param.size(-1))
return torch.gather(param, dim=1, index=idx_to_gather)

def get_fluxes_of_on_sources(self):
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
Returns:
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
"""
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))

def _sort_sources_by_flux(self, band=2):
# sort by fluxes of "on" sources to get brightest source per tile
on_fluxes = self.get_fluxes_of_on_sources()[..., band] # shape n x nth x ntw x d
on_fluxes = self.on_fluxes[..., band] # shape n x nth x ntw x d
top_indexes = on_fluxes.argsort(dim=3, descending=True)

d = {"n_sources": self.n_sources}
Expand Down Expand Up @@ -310,7 +323,7 @@ def filter_tile_catalog_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
sorted_self = self._sort_sources_by_flux(band=band)

# get fluxes of "on" sources to mask by
on_fluxes = sorted_self.get_fluxes_of_on_sources()[..., band]
on_fluxes = sorted_self.on_fluxes[..., band]
flux_mask = (on_fluxes > min_flux) & (on_fluxes < max_flux)

d = {}
Expand Down Expand Up @@ -443,25 +456,42 @@ def to(self, device):
def device(self):
return self.plocs.device

def get_is_on_mask(self) -> Tensor:
@property
def is_on_mask(self) -> Tensor:
arange = torch.arange(self.max_sources, device=self.device)
return arange.view(1, -1) < self.n_sources.view(-1, 1)

@property
def star_bools(self) -> Tensor:
is_star = self["source_type"] == SourceType.STAR
assert is_star.size(1) == self.max_sources
is_on_mask = self.get_is_on_mask()
assert is_star.size(2) == 1
return is_star * is_on_mask.unsqueeze(2)
return is_star * self.is_on_mask.unsqueeze(2)

@property
def galaxy_bools(self) -> Tensor:
is_galaxy = self["source_type"] == SourceType.GALAXY
assert is_galaxy.size(1) == self.max_sources
is_on_mask = self.get_is_on_mask()
assert is_galaxy.size(2) == 1
return is_galaxy * is_on_mask.unsqueeze(2)
return is_galaxy * self.is_on_mask.unsqueeze(2)

@property
def on_fluxes(self) -> Tensor:
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
Returns:
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
"""
if "fluxes" in self:
# ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes
fluxes = self.get("fluxes")
else:
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))

@property
def magnitudes(self):
return convert_nmgy_to_mag(self.on_fluxes)

def one_source(self, b: int, s: int):
"""Return a dict containing all parameter for one specified light source."""
Expand Down Expand Up @@ -613,12 +643,12 @@ def to_astropy_table(self, encoder_survey_bands: Tuple[str]) -> Table:

# Convert dictionary of tensors to list of dictionaries
on_vals = {}
is_on_mask = self.get_is_on_mask()
is_on_mask = self.is_on_mask
for k, v in self.to_dict().items():
if k == "n_sources":
continue
if k == "galaxy_params":
# reshape get_is_on_mask() to have same last dimension as galaxy_params
# reshape is_on_mask to have same last dimension as galaxy_params
galaxy_params_mask = is_on_mask.unsqueeze(-1).expand_as(v)
on_vals[k] = v[galaxy_params_mask].reshape(-1, v.shape[-1]).cpu()
else:
Expand Down
25 changes: 19 additions & 6 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ encoder:
scheduler_params:
milestones: [32]
gamma: 0.1
metrics:
_target_: bliss.encoder.metrics.CatalogMetrics
slack: 1.0
mode: matching
survey_bands: ${encoder.survey_bands}
image_normalizer:
_target_: bliss.encoder.image_normalizer.ImageNormalizer
bands: [0, 1, 2, 3, 4]
Expand All @@ -90,9 +85,27 @@ encoder:
log_transform_stdevs: [-3, 0, 1, 3]
use_clahe: true
clahe_min_stdev: 200
vd_spec:
_target_: bliss.encoder.variational_dist.VariationalDistSpec
survey_bands: ${encoder.survey_bands}
tile_slen: ${encoder.tile_slen}
matcher:
_target_: bliss.encoder.metrics.CatalogMatcher
dist_slack: 1.0
mag_slack: null
mag_band: 2 # SDSS r-band
metrics:
_target_: torchmetrics.MetricCollection
metrics:
- _target_: bliss.encoder.metrics.DetectionPerformance
mag_bin_cutoffs: [19, 19.4, 19.8, 20.2, 20.6, 21, 21.4, 21.8]
- _target_: bliss.encoder.metrics.SourceTypeAccuracy
- _target_: bliss.encoder.metrics.FluxError
survey_bands: ${encoder.survey_bands}
- _target_: bliss.encoder.metrics.GalaxyShapeError
do_data_augmentation: false
compile_model: false # if true, compile model for potential performance
two_layers: false
double_detect: false

surveys:
sdss:
Expand Down
Loading

0 comments on commit b14177c

Please sign in to comment.