Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
declanmcnamara committed Feb 10, 2025
1 parent e4a8a4c commit 42d059c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 38 deletions.
12 changes: 3 additions & 9 deletions case_studies/redshift/evaluation/evaluate_cts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from hydra.utils import instantiate
from omegaconf import DictConfig
from tqdm import tqdm
from hydra import compose, initialize
from case_studies.redshift.redshift_from_img.encoder.encoder import RedshiftsEncoder


def get_best_ckpt(ckpt_dir: str):
Expand All @@ -21,26 +19,23 @@ def get_best_ckpt(ckpt_dir: str):
raise FileExistsError("No ckpt files found in the directory")


# @hydra.main(config_path=".", config_name="continuous_eval")
@hydra.main(config_path=".", config_name="continuous_eval")
def main(cfg: DictConfig):
with initialize(config_path=".", job_name="continuous_eval"):
cfg = compose(config_name="continuous_eval")

output_dir = cfg.paths.plot_dir
ckpt_dir = cfg.paths.ckpt_dir

output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

ckpt_path = get_best_ckpt(ckpt_dir)
device = torch.device('cpu')
device = torch.device("cpu")

# set up testing dataset
dataset = instantiate(cfg.train.data_source)
dataset.setup("test")

# load bliss trained model - continuous version
bliss_encoder: RedshiftsEncoder = instantiate(cfg.encoder).to(device=device)
bliss_encoder = instantiate(cfg.encoder).to(device=device)
pretrained_weights = torch.load(ckpt_path, device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval()
Expand All @@ -52,7 +47,6 @@ def main(cfg: DictConfig):
if not bliss_output_path.exists():
test_loader = dataset.test_dataloader()
for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
batch["images"] = batch["images"].to(device)
bliss_encoder.update_metrics(batch, batch_idx)
bliss_out_dict = bliss_encoder.mode_metrics.compute()

Expand Down
2 changes: 1 addition & 1 deletion case_studies/redshift/evaluation/evaluate_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(cfg: DictConfig):
output_dir.mkdir(parents=True, exist_ok=True)

ckpt_path = get_best_ckpt(ckpt_dir)
device = torch.device("cpu") # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# set up testing dataset
dataset = instantiate(cfg.train.data_source)
Expand Down
4 changes: 1 addition & 3 deletions case_studies/redshift/evaluation/plots_bliss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from omegaconf import DictConfig
from hydra import compose, initialize


@hydra.main(config_path=".", config_name="discrete_eval")
def main(cfg: DictConfig):
with initialize(config_path=".", job_name="discrete_eval"):
cfg = compose(config_name="discrete_eval")
output_dir = Path(cfg.paths.plot_dir)

# Load metric results
Expand Down
38 changes: 32 additions & 6 deletions case_studies/redshift/redshift_from_img/encoder/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
from bliss.encoder.metrics import CatalogMatcher


def convert_nmgy_to_njymag(nmgy):
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
the "Flux units: maggies and nanomaggies" part of
https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
When we change the standard source to AB sources, we need to do the conversion
described in "2.10 AB magnitudes" at
https://pstn-001.lsst.io/fluxunits.pdf
Args:
nmgy: the fluxes in nanomaggies
Returns:
Tensor indicating fluxes in AB magnitude
"""

return 22.5 - 2.5 * torch.log10(nmgy / 3631)


class MetricBin(Metric):
def __init__(
self,
Expand Down Expand Up @@ -67,6 +87,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -80,7 +101,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

red_err = (true_red - est_red).abs() ** 2
Expand Down Expand Up @@ -238,6 +259,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -251,7 +273,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metric_outlier = torch.abs(true_red - est_red) / (1 + true_red)
Expand Down Expand Up @@ -322,6 +344,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -335,7 +358,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metric_outlier_cata = torch.abs(true_red - est_red)
Expand Down Expand Up @@ -399,6 +422,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -412,7 +436,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = (true_red - est_red) / (1 + true_red)
Expand Down Expand Up @@ -492,6 +516,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -505,7 +530,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = (true_red - est_red) / (1 + true_red)
Expand Down Expand Up @@ -577,6 +602,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

Expand All @@ -590,7 +616,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = torch.abs(true_red - est_red) / (1 + true_red)
Expand Down
38 changes: 19 additions & 19 deletions case_studies/redshift/runner.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@ export MKL_NUM_THREADS="16"
export NUMEXPR_NUM_THREADS="16"

# Produce data artifacts
# echo "producing data artifacts for BLISS and RAIL from DC2"
# python artifacts/data_generation.py
echo "producing data artifacts for BLISS and RAIL from DC2"
python artifacts/data_generation.py

# # Run BLISS (discrete variational distribution)
# DIRNAME="$OUT_DIR/discrete"
# Run BLISS (discrete variational distribution)
DIRNAME="$OUT_DIR/discrete"

# if [ ! -d "$DIRNAME" ]; then
# mkdir -p "$DIRNAME"
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
# else
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
# fi
if [ ! -d "$DIRNAME" ]; then
mkdir -p "$DIRNAME"
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
else
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
fi

# nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn discrete > "$DIRNAME/output.out" 2>&1 &
nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn discrete > "$DIRNAME/output.out" 2>&1 &

# Run BLISS (continuous variational distribution)
# DIRNAME="$OUT_DIR/continuous"
DIRNAME="$OUT_DIR/continuous"

# if [ ! -d "$DIRNAME" ]; then
# mkdir -p "$DIRNAME"
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
# else
# echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
# fi
if [ ! -d "$DIRNAME" ]; then
mkdir -p "$DIRNAME"
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
else
echo "BLISS training logs/checkpoints will be saved to $DIRNAME"
fi

# nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn continuous > "$DIRNAME/output.out" 2>&1 &
nohup python bliss/main.py -cp ~/bliss/case_studies/redshift/redshift_from_img -cn continuous > "$DIRNAME/output.out" 2>&1 &

# # Run RAIL
# # TODO
Expand Down

0 comments on commit 42d059c

Please sign in to comment.