Skip to content

Commit b14177c

Browse files
authored
overhaul metrics (#951)
* moved dist_param_groups to VariationalGrid * now Encoder takes a VariationalGridMaker instance * rename 'layers' to 'detections' * forgot to add variational_grid.py * renaming. VariationalGridMaker -> VariationalDistSpec; pred -> factor * require magnitudes to match too in CatalogMetrics * remove tile matching; improve metrics * tests passing with some hacky stuff * update does it all * only compute galsim param error if there are galaxies * fixed tests * remove sklearn dependence in metrics * don't couple metrics and vardist just to access GALSIM_NAMES * manage metrics manually * remove gal_fp and star_fp * f1 -> detection_f1 * computing recall and precision per magnitude bin * MetricCollection * added plotting routine to show detection performance binned by magnitude * exclude last magnitude bin
1 parent 20ba5fc commit b14177c

File tree

11 files changed

+584
-576
lines changed

11 files changed

+584
-576
lines changed

bliss/catalog.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
from torch import Tensor
1313

1414

15+
def convert_mag_to_nmgy(mag):
16+
return 10 ** ((22.5 - mag) / 2.5)
17+
18+
19+
def convert_nmgy_to_mag(nmgy):
20+
return 22.5 - 2.5 * torch.log10(nmgy)
21+
22+
1523
class SourceType(IntEnum):
1624
STAR = 0
1725
GALAXY = 1
@@ -112,6 +120,20 @@ def galaxy_bools(self) -> Tensor:
112120
is_galaxy = self["source_type"] == SourceType.GALAXY
113121
return is_galaxy * self.is_on_mask.unsqueeze(-1)
114122

123+
@property
124+
def on_fluxes(self):
125+
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
126+
127+
Returns:
128+
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
129+
"""
130+
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
131+
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
132+
133+
@property
134+
def magnitudes(self):
135+
return convert_nmgy_to_mag(self.on_fluxes)
136+
115137
@property
116138
def device(self):
117139
return self.locs.device
@@ -242,18 +264,9 @@ def gather_param_at_tiles(self, param_name: str, indices: Tensor) -> Tensor:
242264
idx_to_gather = repeat(indices, "... -> ... k", k=param.size(-1))
243265
return torch.gather(param, dim=1, index=idx_to_gather)
244266

245-
def get_fluxes_of_on_sources(self):
246-
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
247-
248-
Returns:
249-
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
250-
"""
251-
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
252-
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
253-
254267
def _sort_sources_by_flux(self, band=2):
255268
# sort by fluxes of "on" sources to get brightest source per tile
256-
on_fluxes = self.get_fluxes_of_on_sources()[..., band] # shape n x nth x ntw x d
269+
on_fluxes = self.on_fluxes[..., band] # shape n x nth x ntw x d
257270
top_indexes = on_fluxes.argsort(dim=3, descending=True)
258271

259272
d = {"n_sources": self.n_sources}
@@ -310,7 +323,7 @@ def filter_tile_catalog_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
310323
sorted_self = self._sort_sources_by_flux(band=band)
311324

312325
# get fluxes of "on" sources to mask by
313-
on_fluxes = sorted_self.get_fluxes_of_on_sources()[..., band]
326+
on_fluxes = sorted_self.on_fluxes[..., band]
314327
flux_mask = (on_fluxes > min_flux) & (on_fluxes < max_flux)
315328

316329
d = {}
@@ -443,25 +456,42 @@ def to(self, device):
443456
def device(self):
444457
return self.plocs.device
445458

446-
def get_is_on_mask(self) -> Tensor:
459+
@property
460+
def is_on_mask(self) -> Tensor:
447461
arange = torch.arange(self.max_sources, device=self.device)
448462
return arange.view(1, -1) < self.n_sources.view(-1, 1)
449463

450464
@property
451465
def star_bools(self) -> Tensor:
452466
is_star = self["source_type"] == SourceType.STAR
453467
assert is_star.size(1) == self.max_sources
454-
is_on_mask = self.get_is_on_mask()
455468
assert is_star.size(2) == 1
456-
return is_star * is_on_mask.unsqueeze(2)
469+
return is_star * self.is_on_mask.unsqueeze(2)
457470

458471
@property
459472
def galaxy_bools(self) -> Tensor:
460473
is_galaxy = self["source_type"] == SourceType.GALAXY
461474
assert is_galaxy.size(1) == self.max_sources
462-
is_on_mask = self.get_is_on_mask()
463475
assert is_galaxy.size(2) == 1
464-
return is_galaxy * is_on_mask.unsqueeze(2)
476+
return is_galaxy * self.is_on_mask.unsqueeze(2)
477+
478+
@property
479+
def on_fluxes(self) -> Tensor:
480+
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
481+
482+
Returns:
483+
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
484+
"""
485+
if "fluxes" in self:
486+
# ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes
487+
fluxes = self.get("fluxes")
488+
else:
489+
fluxes = torch.where(self.galaxy_bools, self["galaxy_fluxes"], self["star_fluxes"])
490+
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
491+
492+
@property
493+
def magnitudes(self):
494+
return convert_nmgy_to_mag(self.on_fluxes)
465495

466496
def one_source(self, b: int, s: int):
467497
"""Return a dict containing all parameter for one specified light source."""
@@ -613,12 +643,12 @@ def to_astropy_table(self, encoder_survey_bands: Tuple[str]) -> Table:
613643

614644
# Convert dictionary of tensors to list of dictionaries
615645
on_vals = {}
616-
is_on_mask = self.get_is_on_mask()
646+
is_on_mask = self.is_on_mask
617647
for k, v in self.to_dict().items():
618648
if k == "n_sources":
619649
continue
620650
if k == "galaxy_params":
621-
# reshape get_is_on_mask() to have same last dimension as galaxy_params
651+
# reshape is_on_mask to have same last dimension as galaxy_params
622652
galaxy_params_mask = is_on_mask.unsqueeze(-1).expand_as(v)
623653
on_vals[k] = v[galaxy_params_mask].reshape(-1, v.shape[-1]).cpu()
624654
else:

bliss/conf/base_config.yaml

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@ encoder:
7575
scheduler_params:
7676
milestones: [32]
7777
gamma: 0.1
78-
metrics:
79-
_target_: bliss.encoder.metrics.CatalogMetrics
80-
slack: 1.0
81-
mode: matching
82-
survey_bands: ${encoder.survey_bands}
8378
image_normalizer:
8479
_target_: bliss.encoder.image_normalizer.ImageNormalizer
8580
bands: [0, 1, 2, 3, 4]
@@ -90,9 +85,27 @@ encoder:
9085
log_transform_stdevs: [-3, 0, 1, 3]
9186
use_clahe: true
9287
clahe_min_stdev: 200
88+
vd_spec:
89+
_target_: bliss.encoder.variational_dist.VariationalDistSpec
90+
survey_bands: ${encoder.survey_bands}
91+
tile_slen: ${encoder.tile_slen}
92+
matcher:
93+
_target_: bliss.encoder.metrics.CatalogMatcher
94+
dist_slack: 1.0
95+
mag_slack: null
96+
mag_band: 2 # SDSS r-band
97+
metrics:
98+
_target_: torchmetrics.MetricCollection
99+
metrics:
100+
- _target_: bliss.encoder.metrics.DetectionPerformance
101+
mag_bin_cutoffs: [19, 19.4, 19.8, 20.2, 20.6, 21, 21.4, 21.8]
102+
- _target_: bliss.encoder.metrics.SourceTypeAccuracy
103+
- _target_: bliss.encoder.metrics.FluxError
104+
survey_bands: ${encoder.survey_bands}
105+
- _target_: bliss.encoder.metrics.GalaxyShapeError
93106
do_data_augmentation: false
94107
compile_model: false # if true, compile model for potential performance
95-
two_layers: false
108+
double_detect: false
96109

97110
surveys:
98111
sdss:

0 commit comments

Comments
 (0)