Skip to content

Commit 42d059c

Browse files
linting
1 parent e4a8a4c commit 42d059c

File tree

5 files changed

+56
-38
lines changed

5 files changed

+56
-38
lines changed

case_studies/redshift/evaluation/evaluate_cts.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from hydra.utils import instantiate
88
from omegaconf import DictConfig
99
from tqdm import tqdm
10-
from hydra import compose, initialize
11-
from case_studies.redshift.redshift_from_img.encoder.encoder import RedshiftsEncoder
1210

1311

1412
def get_best_ckpt(ckpt_dir: str):
@@ -21,26 +19,23 @@ def get_best_ckpt(ckpt_dir: str):
2119
raise FileExistsError("No ckpt files found in the directory")
2220

2321

24-
# @hydra.main(config_path=".", config_name="continuous_eval")
22+
@hydra.main(config_path=".", config_name="continuous_eval")
2523
def main(cfg: DictConfig):
26-
with initialize(config_path=".", job_name="continuous_eval"):
27-
cfg = compose(config_name="continuous_eval")
28-
2924
output_dir = cfg.paths.plot_dir
3025
ckpt_dir = cfg.paths.ckpt_dir
3126

3227
output_dir = Path(output_dir)
3328
output_dir.mkdir(parents=True, exist_ok=True)
3429

3530
ckpt_path = get_best_ckpt(ckpt_dir)
36-
device = torch.device('cpu')
31+
device = torch.device("cpu")
3732

3833
# set up testing dataset
3934
dataset = instantiate(cfg.train.data_source)
4035
dataset.setup("test")
4136

4237
# load bliss trained model - continuous version
43-
bliss_encoder: RedshiftsEncoder = instantiate(cfg.encoder).to(device=device)
38+
bliss_encoder = instantiate(cfg.encoder).to(device=device)
4439
pretrained_weights = torch.load(ckpt_path, device)["state_dict"]
4540
bliss_encoder.load_state_dict(pretrained_weights)
4641
bliss_encoder.eval()
@@ -52,7 +47,6 @@ def main(cfg: DictConfig):
5247
if not bliss_output_path.exists():
5348
test_loader = dataset.test_dataloader()
5449
for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
55-
batch["images"] = batch["images"].to(device)
5650
bliss_encoder.update_metrics(batch, batch_idx)
5751
bliss_out_dict = bliss_encoder.mode_metrics.compute()
5852

case_studies/redshift/evaluation/evaluate_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def main(cfg: DictConfig):
2727
output_dir.mkdir(parents=True, exist_ok=True)
2828

2929
ckpt_path = get_best_ckpt(ckpt_dir)
30-
device = torch.device("cpu") # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
30+
device = torch.device("cpu")
3131

3232
# set up testing dataset
3333
dataset = instantiate(cfg.train.data_source)

case_studies/redshift/evaluation/plots_bliss.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66
from matplotlib import pyplot as plt
77
from matplotlib.ticker import FormatStrFormatter
88
from omegaconf import DictConfig
9-
from hydra import compose, initialize
9+
1010

1111
@hydra.main(config_path=".", config_name="discrete_eval")
1212
def main(cfg: DictConfig):
13-
with initialize(config_path=".", job_name="discrete_eval"):
14-
cfg = compose(config_name="discrete_eval")
1513
output_dir = Path(cfg.paths.plot_dir)
1614

1715
# Load metric results

case_studies/redshift/redshift_from_img/encoder/metrics.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@
1010
from bliss.encoder.metrics import CatalogMatcher
1111

1212

13+
def convert_nmgy_to_njymag(nmgy):
14+
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
15+
16+
For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
17+
the "Flux units: maggies and nanomaggies" part of
18+
https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
19+
When we change the standard source to AB sources, we need to do the conversion
20+
described in "2.10 AB magnitudes" at
21+
https://pstn-001.lsst.io/fluxunits.pdf
22+
23+
Args:
24+
nmgy: the fluxes in nanomaggies
25+
26+
Returns:
27+
Tensor indicating fluxes in AB magnitude
28+
"""
29+
30+
return 22.5 - 2.5 * torch.log10(nmgy / 3631)
31+
32+
1333
class MetricBin(Metric):
1434
def __init__(
1535
self,
@@ -67,6 +87,7 @@ def __init__(
6787

6888
def update(self, true_cat, est_cat, matching):
6989
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
90+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
7091
for i in range(true_cat.batch_size):
7192
tcat_matches, ecat_matches = matching[i]
7293

@@ -80,7 +101,7 @@ def update(self, true_cat, est_cat, matching):
80101
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
81102
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
82103

83-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
104+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
84105
bin_indices = torch.bucketize(true_mag, cutoffs)
85106

86107
red_err = (true_red - est_red).abs() ** 2
@@ -238,6 +259,7 @@ def __init__(
238259

239260
def update(self, true_cat, est_cat, matching):
240261
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
262+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
241263
for i in range(true_cat.batch_size):
242264
tcat_matches, ecat_matches = matching[i]
243265

@@ -251,7 +273,7 @@ def update(self, true_cat, est_cat, matching):
251273
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
252274
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
253275

254-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
276+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
255277
bin_indices = torch.bucketize(true_mag, cutoffs)
256278

257279
metric_outlier = torch.abs(true_red - est_red) / (1 + true_red)
@@ -322,6 +344,7 @@ def __init__(
322344

323345
def update(self, true_cat, est_cat, matching):
324346
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
347+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
325348
for i in range(true_cat.batch_size):
326349
tcat_matches, ecat_matches = matching[i]
327350

@@ -335,7 +358,7 @@ def update(self, true_cat, est_cat, matching):
335358
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
336359
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
337360

338-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
361+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
339362
bin_indices = torch.bucketize(true_mag, cutoffs)
340363

341364
metric_outlier_cata = torch.abs(true_red - est_red)
@@ -399,6 +422,7 @@ def __init__(self, **kwargs):
399422

400423
def update(self, true_cat, est_cat, matching):
401424
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
425+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
402426
for i in range(true_cat.batch_size):
403427
tcat_matches, ecat_matches = matching[i]
404428

@@ -412,7 +436,7 @@ def update(self, true_cat, est_cat, matching):
412436
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
413437
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
414438

415-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
439+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
416440
bin_indices = torch.bucketize(true_mag, cutoffs)
417441

418442
metrics = (true_red - est_red) / (1 + true_red)
@@ -492,6 +516,7 @@ def __init__(self, **kwargs):
492516

493517
def update(self, true_cat, est_cat, matching):
494518
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
519+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
495520
for i in range(true_cat.batch_size):
496521
tcat_matches, ecat_matches = matching[i]
497522

@@ -505,7 +530,7 @@ def update(self, true_cat, est_cat, matching):
505530
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
506531
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
507532

508-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
533+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
509534
bin_indices = torch.bucketize(true_mag, cutoffs)
510535

511536
metrics = (true_red - est_red) / (1 + true_red)
@@ -577,6 +602,7 @@ def __init__(self, **kwargs):
577602

578603
def update(self, true_cat, est_cat, matching):
579604
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
605+
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
580606
for i in range(true_cat.batch_size):
581607
tcat_matches, ecat_matches = matching[i]
582608

@@ -590,7 +616,7 @@ def update(self, true_cat, est_cat, matching):
590616
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
591617
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)
592618

593-
true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
619+
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
594620
bin_indices = torch.bucketize(true_mag, cutoffs)
595621

596622
metrics = torch.abs(true_red - est_red) / (1 + true_red)

case_studies/redshift/runner.sh

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,32 @@ export MKL_NUM_THREADS="16"
66
export NUMEXPR_NUM_THREADS="16"
77

88
# Produce data artifacts
9-
# echo "producing data artifacts for BLISS and RAIL from DC2"
10-
# python artifacts/data_generation.py
9+
echo "producing data artifacts for BLISS and RAIL from DC2"
10+
python artifacts/data_generation.py
1111

12-
# # Run BLISS (discrete variational distribution)
13-
# DIRNAME="$OUT_DIR/discrete"
12+
# Run BLISS (discrete variational distribution)
13+
DIRNAME="$OUT_DIR/discrete"
1414

15-
# if [ ! -d "$DIRNAME" ]; then
16-
# mkdir -p "$DIRNAME"
17-
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
18-
# else
19-
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
20-
# fi
15+
if [ ! -d "$DIRNAME" ]; then
16+
mkdir -p "$DIRNAME"
17+
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
18+
else
19+
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
20+
fi
2121

22-
# nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn discrete > "$DIRNAME/output.out" 2>&1 &
22+
nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn discrete > "$DIRNAME/output.out" 2>&1 &
2323

2424
# Run BLISS (continuous variational distribution)
25-
# DIRNAME="$OUT_DIR/continuous"
25+
DIRNAME="$OUT_DIR/continuous"
2626

27-
# if [ ! -d "$DIRNAME" ]; then
28-
# mkdir -p "$DIRNAME"
29-
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
30-
# else
31-
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
32-
# fi
27+
if [ ! -d "$DIRNAME" ]; then
28+
mkdir -p "$DIRNAME"
29+
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
30+
else
31+
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
32+
fi
3333

34-
# nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn continuous > "$DIRNAME/output.out" 2>&1 &
34+
nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn continuous > "$DIRNAME/output.out" 2>&1 &
3535

3636
# # Run RAIL
3737
# # TODO

0 commit comments

Comments
 (0)