Skip to content

Commit 2cefde4

Browse files
YicunDuanUMichYicun Duan
andauthored
Regular update (#1047)
* update notebooks * update notebooks * update config * update exp_record * update MultiDetectEncoder * update MultiDetectEncoder * update notebooks; update SourceTypeAccuracy to report w.r.t magnitudes * update notebooks * add ellipticity estimation * fix tests * fix the warning in sampler * add a new notebook to show the mse of ellipticity estimation * change MSE to residual for ellipticity measure * modify units in catalog; remove background from dc2 * replace on_fluxes, magnitudes, magnitudes_njy with on_nmgy, on_mag and on_njy * fix notebook after merge * change the gating logic of VariationalFactor * delete old notebooks * address pr comments --------- Co-authored-by: Yicun Duan <[email protected]>
1 parent d84e937 commit 2cefde4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3998
-3252
lines changed

bliss/cached_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __call__(self, datum_in):
9090

9191
class ChunkingSampler(Sampler):
9292
def __init__(self, dataset: Dataset) -> None:
93-
super().__init__(dataset)
93+
super().__init__()
9494
assert isinstance(dataset, ChunkingDataset), "dataset should be ChunkingDataset"
9595
self.dataset = dataset
9696

bliss/catalog.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,22 @@ def convert_nmgy_to_mag(nmgy):
1919

2020

2121
def convert_nmgy_to_njymag(nmgy):
22-
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2."""
22+
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
23+
24+
For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
25+
the "Flux units: maggies and nanomaggies" part of
26+
https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
27+
When we change the standard source to AB sources, we need to do the conversion
28+
described in "2.10 AB magnitudes" at
29+
https://pstn-001.lsst.io/fluxunits.pdf
30+
31+
Args:
32+
nmgy: the fluxes in nanomaggies
33+
34+
Returns:
35+
Tensor indicating fluxes in AB magnitude
36+
"""
37+
2338
return 22.5 - 2.5 * torch.log10(nmgy / 3631)
2439

2540

@@ -148,13 +163,19 @@ def galaxy_bools(self) -> Tensor:
148163
is_galaxy = self["source_type"] == SourceType.GALAXY
149164
return is_galaxy * self.is_on_mask.unsqueeze(-1)
150165

151-
@property
152-
def on_fluxes(self):
153-
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
166+
def on_fluxes(self, unit: str):
167+
match unit:
168+
case "nmgy":
169+
return self.on_nmgy
170+
case "mag":
171+
return self.on_mag
172+
case "njymag":
173+
return self.on_njymag
174+
case _:
175+
raise NotImplementedError()
154176

155-
Returns:
156-
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
157-
"""
177+
@property
178+
def on_nmgy(self):
158179
# TODO: a tile catalog should store fluxes rather than star_fluxes and galaxy_fluxes
159180
# because that's all that's needed to render the source
160181
if "galaxy_fluxes" not in self:
@@ -164,13 +185,12 @@ def on_fluxes(self):
164185
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
165186

166187
@property
167-
def magnitudes(self):
168-
# TODO: we shouldn't assume fluxes are stored in nanomaggies because they aren't for DC2
169-
return convert_nmgy_to_mag(self.on_fluxes)
188+
def on_mag(self) -> Tensor:
189+
return convert_nmgy_to_mag(self.on_nmgy)
170190

171191
@property
172-
def magnitudes_njy(self):
173-
return convert_nmgy_to_njymag(self.on_fluxes)
192+
def on_njymag(self) -> Tensor:
193+
return convert_nmgy_to_njymag(self.on_nmgy)
174194

175195
def to_full_catalog(self, tile_slen):
176196
"""Converts image parameters in tiles to parameters of full image.
@@ -266,8 +286,8 @@ def get_indices_of_on_sources(self) -> Tuple[Tensor, Tensor]:
266286

267287
def _sort_sources_by_flux(self, band=2):
268288
# sort by fluxes of "on" sources to get brightest source per tile
269-
on_fluxes = self.on_fluxes[..., band] # shape n x nth x ntw x d
270-
top_indexes = on_fluxes.argsort(dim=3, descending=True)
289+
on_nmgy = self.on_nmgy[..., band] # shape n x nth x ntw x d
290+
top_indexes = on_nmgy.argsort(dim=3, descending=True)
271291

272292
d = {"n_sources": self["n_sources"]}
273293
for key, val in self.items():
@@ -323,8 +343,8 @@ def filter_by_flux(self, min_flux=0, max_flux=torch.inf, band=2):
323343
sorted_self = self._sort_sources_by_flux(band=band)
324344

325345
# get fluxes of "on" sources to mask by
326-
on_fluxes = sorted_self.on_fluxes[..., band]
327-
flux_mask = (on_fluxes > min_flux) & (on_fluxes < max_flux)
346+
on_nmgy = sorted_self.on_nmgy[..., band]
347+
flux_mask = (on_nmgy > min_flux) & (on_nmgy < max_flux)
328348

329349
d = {}
330350
for key, val in sorted_self.items():
@@ -458,13 +478,19 @@ def galaxy_bools(self) -> Tensor:
458478
assert is_galaxy.size(2) == 1
459479
return is_galaxy * self.is_on_mask.unsqueeze(2)
460480

461-
@property
462-
def on_fluxes(self) -> Tensor:
463-
"""Gets fluxes of "on" sources based on whether the source is a star or galaxy.
481+
def on_fluxes(self, unit: str):
482+
match unit:
483+
case "nmgy":
484+
return self.on_nmgy
485+
case "mag":
486+
return self.on_mag
487+
case "njymag":
488+
return self.on_njymag
489+
case _:
490+
raise NotImplementedError()
464491

465-
Returns:
466-
Tensor: a tensor of fluxes of size (b x nth x ntw x max_sources x 1)
467-
"""
492+
@property
493+
def on_nmgy(self) -> Tensor:
468494
# ideally we'd always store fluxes rather than star_fluxes and galaxy_fluxes
469495
if "galaxy_fluxes" not in self:
470496
fluxes = self["star_fluxes"]
@@ -473,8 +499,12 @@ def on_fluxes(self) -> Tensor:
473499
return torch.where(self.is_on_mask[..., None], fluxes, torch.zeros_like(fluxes))
474500

475501
@property
476-
def magnitudes(self):
477-
return convert_nmgy_to_mag(self.on_fluxes)
502+
def on_mag(self) -> Tensor:
503+
return convert_nmgy_to_mag(self.on_nmgy)
504+
505+
@property
506+
def on_njymag(self) -> Tensor:
507+
return convert_nmgy_to_njymag(self.on_nmgy)
478508

479509
def one_source(self, b: int, s: int):
480510
"""Return a dict containing all parameter for one specified light source."""

bliss/conf/base_config.yaml

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,64 +80,75 @@ variational_factors:
8080
name: locs
8181
sample_rearrange: "b ht wt d -> b ht wt 1 d"
8282
nll_rearrange: "b ht wt 1 d -> b ht wt d"
83-
nll_gating: n_sources
83+
nll_gating:
84+
_target_: bliss.encoder.variational_dist.SourcesGating
8485
- _target_: bliss.encoder.variational_dist.BernoulliFactor
8586
name: source_type
8687
sample_rearrange: "b ht wt -> b ht wt 1 1"
8788
nll_rearrange: "b ht wt 1 1 -> b ht wt"
88-
nll_gating: n_sources
89+
nll_gating:
90+
_target_: bliss.encoder.variational_dist.SourcesGating
8991
- _target_: bliss.encoder.variational_dist.LogNormalFactor
9092
name: star_fluxes
9193
dim: 5
9294
sample_rearrange: "b ht wt d -> b ht wt 1 d"
9395
nll_rearrange: "b ht wt 1 d -> b ht wt d"
94-
nll_gating: is_star
96+
nll_gating:
97+
_target_: bliss.encoder.variational_dist.StarGating
9598
- _target_: bliss.encoder.variational_dist.LogNormalFactor
9699
name: galaxy_fluxes
97100
dim: 5
98101
sample_rearrange: "b ht wt d -> b ht wt 1 d"
99102
nll_rearrange: "b ht wt 1 d -> b ht wt d"
100-
nll_gating: is_galaxy
103+
nll_gating:
104+
_target_: bliss.encoder.variational_dist.GalaxyGating
101105
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
102106
name: galaxy_disk_frac
103107
sample_rearrange: "b ht wt d -> b ht wt 1 d"
104108
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
105-
nll_gating: is_galaxy
109+
nll_gating:
110+
_target_: bliss.encoder.variational_dist.GalaxyGating
106111
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
107112
name: galaxy_beta_radians
108113
high: 3.1415926
109114
sample_rearrange: "b ht wt d -> b ht wt 1 d"
110115
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
111-
nll_gating: is_galaxy
116+
nll_gating:
117+
_target_: bliss.encoder.variational_dist.GalaxyGating
112118
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
113119
name: galaxy_disk_q
114120
sample_rearrange: "b ht wt d -> b ht wt 1 d"
115121
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
116-
nll_gating: is_galaxy
122+
nll_gating:
123+
_target_: bliss.encoder.variational_dist.GalaxyGating
117124
- _target_: bliss.encoder.variational_dist.LogNormalFactor
118125
name: galaxy_a_d
119126
sample_rearrange: "b ht wt d -> b ht wt 1 d"
120127
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
121-
nll_gating: is_galaxy
128+
nll_gating:
129+
_target_: bliss.encoder.variational_dist.GalaxyGating
122130
- _target_: bliss.encoder.variational_dist.LogitNormalFactor
123131
name: galaxy_bulge_q
124132
sample_rearrange: "b ht wt d -> b ht wt 1 d"
125133
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
126-
nll_gating: is_galaxy
134+
nll_gating:
135+
_target_: bliss.encoder.variational_dist.GalaxyGating
127136
- _target_: bliss.encoder.variational_dist.LogNormalFactor
128137
name: galaxy_a_b
129138
sample_rearrange: "b ht wt d -> b ht wt 1 d"
130139
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
131-
nll_gating: is_galaxy
140+
nll_gating:
141+
_target_: bliss.encoder.variational_dist.GalaxyGating
132142

133143
metrics:
134144
detection_performance:
135145
_target_: bliss.encoder.metrics.DetectionPerformance
136-
mag_bin_cutoffs: [19, 19.4, 19.8, 20.2, 20.6, 21, 21.4, 21.8]
137-
mag_band: 2
146+
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
147+
bin_type: "njymag"
138148
source_type_accuracy:
139149
_target_: bliss.encoder.metrics.SourceTypeAccuracy
140-
flux_bin_cutoffs: [200, 400, 600, 800, 1000]
150+
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
151+
bin_type: "njymag"
141152
flux_error:
142153
_target_: bliss.encoder.metrics.FluxError
143154
survey_bands: ${encoder.survey_bands}
@@ -230,7 +241,7 @@ surveys:
230241
n_image_split: 50
231242
tile_slen: 4
232243
max_sources_per_tile: 5
233-
min_flux: 0.0
244+
catalog_min_r_flux: 50
234245
prepare_data_processes_num: 4
235246
data_in_one_cached_file: 1250
236247
splits: 0:80/80:90/90:100
@@ -242,7 +253,13 @@ surveys:
242253
- _target_: bliss.data_augmentation.RandomShiftTransform
243254
tile_slen: ${surveys.dc2.tile_slen}
244255
max_sources_per_tile: ${surveys.dc2.tile_slen}
245-
nontrain_transforms: []
256+
- _target_: bliss.cached_dataset.FluxFilterTransform
257+
reference_band: 2 # r-band
258+
min_flux: 100
259+
nontrain_transforms:
260+
- _target_: bliss.cached_dataset.FluxFilterTransform
261+
reference_band: 2 # r-band
262+
min_flux: 100
246263

247264

248265
#######################################################################

bliss/encoder/encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
self,
2727
survey_bands: list,
2828
tile_slen: int,
29-
image_normalizers: list,
29+
image_normalizers: dict,
3030
var_dist: VariationalDist,
3131
matcher: CatalogMatcher,
3232
sample_image_renders: MetricCollection,
@@ -105,7 +105,7 @@ def make_context(self, history_cat, history_mask, detection2=False):
105105
)
106106
else:
107107
centered_locs = history_cat["locs"][..., 0, :] - 0.5
108-
log_fluxes = (history_cat.on_fluxes.squeeze(3).sum(-1) + 1).log()
108+
log_fluxes = (history_cat.on_nmgy.squeeze(3).sum(-1) + 1).log()
109109
history_encoding_lst = [
110110
history_cat["n_sources"].float(), # detection history
111111
log_fluxes * history_cat["n_sources"], # flux history

0 commit comments

Comments
 (0)