Skip to content

Commit

Permalink
fix on_magnitudes
Browse files Browse the repository at this point in the history
  • Loading branch information
Yicun Duan committed Aug 19, 2024
1 parent 4a3f13e commit ac52a33
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions case_studies/redshift/redshift_from_img/encoder/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)
Expand Down Expand Up @@ -254,7 +254,7 @@ def update(self, true_cat, est_cat, matching):
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

if self.bin_type == "ab_mag":
true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
elif self.bin_type == "njy":
Expand Down Expand Up @@ -343,7 +343,7 @@ def update(self, true_cat, est_cat, matching):
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

if self.bin_type == "ab_mag":
true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
elif self.bin_type == "njy":
Expand Down Expand Up @@ -425,7 +425,7 @@ def update(self, true_cat, est_cat, matching):
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

if self.bin_type == "ab_mag":
true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
elif self.bin_type == "njy":
Expand Down Expand Up @@ -523,7 +523,7 @@ def update(self, true_cat, est_cat, matching):
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

if self.bin_type == "ab_mag":
true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
elif self.bin_type == "njy":
Expand Down Expand Up @@ -613,7 +613,7 @@ def update(self, true_cat, est_cat, matching):
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

if self.bin_type == "ab_mag":
true_mag = true_cat.on_magnitudes(zero_point=3631)[i][..., self.mag_band][
true_mag = true_cat.on_magnitudes(zero_point=3631e9)[i][..., self.mag_band][
tcat_matches
].to(self.device)
elif self.bin_type == "njy":
Expand Down

0 comments on commit ac52a33

Please sign in to comment.