Final tweaks to get redshift code up to date with master. (#1081)
* tweaks

* made some changes to fit with master; now need to update RedshiftsEncoder

* getting the correct shapes from sampling for RedshiftsEncoder; samples coming through. About to change the `update` method of the metric classes to remove `self.bin_type` indexing of `on_fluxes`. Previously `on_fluxes` must have been non-tensor

* at this commit, everything was running. Still want to run the full pipeline to confirm the plots still look good

* linting

* restore other case studies

* something is up with the binning; checking it out inside the metric classes

* linting
declanmcnamara authored Feb 10, 2025
1 parent b2b4fef commit 1c9384d
Showing 14 changed files with 108 additions and 102 deletions.
24 changes: 22 additions & 2 deletions case_studies/redshift/README.md
@@ -1,4 +1,24 @@
This redshift estimation project is consist of 4 parts:
# BLISS-PZ - BLISS For Photo-z Prediction

#### Running BLISS-PZ on DC2

Modify the config file `redshift.yaml` as follows:
1. Set `paths.data_dir` to a directory where all data artifacts and checkpoints should be stored.
2. Set the `OUT_DIR` variable in `runner.sh` to the same location.
3. Set `paths.dc2` to the location of the DC2 data on your system.

To produce the results from `BLISS-PZ`, run `runner.sh` (you may need to make it executable: `chmod +x runner.sh` from within this directory).

```
./runner.sh
```

The runner bash script launches programs sequentially: first data prep, then two different runs of BLISS, followed by RAIL; plots are produced at the end. For your use case it may be better to run parts of the script on their own; take a look and comment out what you don't need, or run a single stage directly, as in the sketch below.
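For example, a minimal sketch of running only the BLISS data-prep stage, mirroring `create_bliss_artifacts` in `artifacts/data_generation.py` below (the `config_path` and invocation directory are assumptions about where you run it from):

```python
# Hedged sketch: run only the BLISS data-prep stage, mirroring
# create_bliss_artifacts() in artifacts/data_generation.py.
# config_path="." assumes this is run from case_studies/redshift.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="redshift")
    dc2 = instantiate(cfg.surveys.dc2)  # a RedshiftDC2DataModule
    dc2.prepare_data()  # writes cached .pt batches for data loading
```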




<!-- This redshift estimation project is consist of 4 parts:
1. Estimate photo-z using neural network (training data is GT mag and redshift)
2. Estimate photo-z using bliss directly from image.
3. Estimate photo-z using lsst + rail pipeline (model from LSST)
@@ -21,4 +41,4 @@ You can modify config at /home/qiaozhih/bliss/case_studies/redshift/redshift_fro
All training code can be found at /home/qiaozhih/bliss/case_studies/redshift/evaluation/rail/RAIL_estimation_demo.ipynb. Make sure you install rail from and you must make sure you are using the corresponding env from rail instead of the bliss.
4. Evaluate & make plot
Run all the code at /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb
Run all the code at /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb -->
4 changes: 2 additions & 2 deletions case_studies/redshift/artifacts/data_generation.py
@@ -8,7 +8,7 @@
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from case_studies.redshift.artifacts.redshift_dc2 import DC2DataModule
from case_studies.redshift.artifacts.redshift_dc2 import RedshiftDC2DataModule

logging.basicConfig(level=logging.INFO)

@@ -66,7 +66,7 @@ def create_rail_artifacts(rail_cfg: DictConfig):
def create_bliss_artifacts(bliss_cfg: DictConfig):
"""CONSTRUCT BATCHES (.pt files) FOR DATA LOADING."""
logging.info("Creating BLISS artifacts at %s", bliss_cfg.paths.processed_data_dir_bliss)
dc2: DC2DataModule = instantiate(bliss_cfg.surveys.dc2)
dc2: RedshiftDC2DataModule = instantiate(bliss_cfg.surveys.dc2)
dc2.prepare_data()


35 changes: 5 additions & 30 deletions case_studies/redshift/artifacts/redshift_dc2.py
@@ -6,7 +6,7 @@

import torch

from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, split_tensor, unpack_dict
from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, unpack_dict


class RedshiftDC2DataModule(DC2DataModule):
@@ -89,20 +89,11 @@ def generate_cached_data(self, naming_info: tuple): # pylint: disable=W0237,R08
wcs_header_str = result_dict["other_info"]["wcs_header_str"]
psf_params = result_dict["inputs"]["psf_params"]

# split image
split_lim = self.image_lim[0] // self.n_image_split
image_splits = split_tensor(image, split_lim, 1, 2)
image_width_pixels = image.shape[2]
split_image_num_on_width = image_width_pixels // split_lim

# split tile cat
tile_cat_splits = {}
param_list = [
"locs",
"n_sources",
"source_type",
"galaxy_fluxes",
"star_fluxes",
"fluxes",
"redshifts",
"blendedness",
"shear",
@@ -112,27 +103,11 @@
"two_sources_mask",
"more_than_two_sources_mask",
]
for param_name in param_list:
tile_cat_splits[param_name] = split_tensor(
tile_dict[param_name], split_lim // self.tile_slen, 0, 1
)

objid = split_tensor(tile_dict["objid"], split_lim // self.tile_slen, 0, 1)

data_splits = {
"tile_catalog": unpack_dict(tile_cat_splits),
"images": image_splits,
"image_height_index": (
torch.arange(0, len(image_splits)) // split_image_num_on_width
).tolist(),
"image_width_index": (
torch.arange(0, len(image_splits)) % split_image_num_on_width
).tolist(),
"psf_params": [psf_params for _ in range(self.n_image_split**2)],
"objid": objid,
}
splits = self.split_image_and_tile_cat(image, tile_dict, param_list, psf_params)

data_splits = split_list(
unpack_dict(data_splits),
unpack_dict(splits),
sub_list_len=self.data_in_one_cached_file,
)

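The refactor leans on two helpers imported from `bliss.surveys.dc2`. A rough sketch of their assumed semantics (not the actual implementations), showing how a dict of per-split lists becomes per-file chunks of per-split dicts:

```python
# Hedged sketch of the assumed semantics of unpack_dict and split_list.
def unpack_dict(d):
    # dict of equal-length lists -> list of dicts, one per image split
    return [dict(zip(d.keys(), vals)) for vals in zip(*d.values())]

def split_list(lst, sub_list_len):
    # chunk a flat list into sub-lists, one per cached .pt file
    return [lst[i:i + sub_list_len] for i in range(0, len(lst), sub_list_len)]

splits = {"images": [1, 2, 3, 4], "objid": [10, 20, 30, 40]}
chunks = split_list(unpack_dict(splits), sub_list_len=2)
# chunks == [[{'images': 1, 'objid': 10}, {'images': 2, 'objid': 20}],
#            [{'images': 3, 'objid': 30}, {'images': 4, 'objid': 40}]]
```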
2 changes: 1 addition & 1 deletion case_studies/redshift/evaluation/continuous_eval.yaml
@@ -5,7 +5,7 @@ defaults:
# - override hydra/job_logging: stdout

paths:
ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_cts_results/checkpoints
ckpt_dir: ${paths.data_dir}/checkpoints/continuous/checkpoints
plot_dir: ${paths.data_dir}/plots

# To reduce memory usage
2 changes: 1 addition & 1 deletion case_studies/redshift/evaluation/discrete_eval.yaml
@@ -5,7 +5,7 @@ defaults:
# - override hydra/job_logging: stdout

paths:
ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_discrete_results/checkpoints
ckpt_dir: ${paths.data_dir}/checkpoints/discrete/checkpoints
plot_dir: ${paths.data_dir}/plots

# To reduce memory usage
3 changes: 1 addition & 2 deletions case_studies/redshift/evaluation/evaluate_cts.py
@@ -28,7 +28,7 @@ def main(cfg: DictConfig):
output_dir.mkdir(parents=True, exist_ok=True)

ckpt_path = get_best_ckpt(ckpt_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# set up testing dataset
dataset = instantiate(cfg.train.data_source)
@@ -47,7 +47,6 @@ def main(cfg: DictConfig):
if not bliss_output_path.exists():
test_loader = dataset.test_dataloader()
for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
batch["images"] = batch["images"].to(device)
bliss_encoder.update_metrics(batch, batch_idx)
bliss_out_dict = bliss_encoder.mode_metrics.compute()

2 changes: 1 addition & 1 deletion case_studies/redshift/evaluation/evaluate_discrete.py
@@ -27,7 +27,7 @@ def main(cfg: DictConfig):
output_dir.mkdir(parents=True, exist_ok=True)

ckpt_path = get_best_ckpt(ckpt_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# set up testing dataset
dataset = instantiate(cfg.train.data_source)
2 changes: 1 addition & 1 deletion case_studies/redshift/redshift.yaml
@@ -13,7 +13,7 @@ paths:
# Defaults from `base_config`. We modify cached path and use own class
surveys:
dc2:
_target_: bliss.case_studies.redshift.artifacts.redshift_dc2.RedshiftDC2DataModule
_target_: case_studies.redshift.artifacts.redshift_dc2.RedshiftDC2DataModule
cached_data_path: ${paths.processed_data_dir_bliss}
dc2_cat_path: /data/scratch/dc2local/merged_catalog_with_flux_over_50.pkl # we should have a script that makes this on our own

10 changes: 5 additions & 5 deletions case_studies/redshift/redshift_from_img/continuous.yaml
@@ -13,8 +13,8 @@ global_setting:
variational_factors:
- _target_: bliss.encoder.variational_dist.NormalFactor
name: redshifts
sample_rearrange: "b ht wt -> b ht wt 1 1"
nll_rearrange: "b ht wt 1 1 -> b ht wt"
sample_rearrange: "b ht wt 1 -> b ht wt 1 1"
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
nll_gating: is_galaxy

encoder:
@@ -57,7 +57,7 @@ encoder:

# Can optimize to these metrics by choosing bin carefully
discrete_metrics:
redshift_mearn_square_error_bin:
redshift_mean_square_error_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
@@ -80,9 +80,9 @@ discrete_metrics:

# Standard metric computation
mode_sample_metrics:
redshift_mearn_square_error:
redshift_mean_square_error:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredError
redshift_mearn_square_error_bin:
redshift_mean_square_error_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
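The `sample_rearrange`/`nll_rearrange` change tracks an extra trailing unit axis on redshift samples. A minimal einops sketch of the shape flow (the concrete sizes are illustrative assumptions):

```python
import torch
from einops import rearrange

# Hedged sketch: shape bookkeeping behind the updated rearrange patterns.
b, ht, wt = 2, 4, 4
sample = torch.randn(b, ht, wt, 1)  # sampled redshifts: "b ht wt 1"
tile_fmt = rearrange(sample, "b ht wt 1 -> b ht wt 1 1")  # tile-catalog layout
assert tile_fmt.shape == (b, ht, wt, 1, 1)
nll_fmt = rearrange(tile_fmt, "b ht wt 1 1 -> b ht wt 1")  # back for the NLL
assert nll_fmt.shape == (b, ht, wt, 1)
```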
8 changes: 4 additions & 4 deletions case_studies/redshift/redshift_from_img/discrete.yaml
@@ -12,7 +12,7 @@ paths:
root: /home/declan/bliss

variational_factors:
- _target_: bliss.encoder.variational_dist.DiscretizedFactor1D
- _target_: case_studies.redshift.redshift_from_img.encoder.variational_dist.DiscretizedFactor1D
name: redshifts
sample_rearrange: "b ht wt -> b ht wt 1 1"
nll_rearrange: "b ht wt 1 1 -> b ht wt"
@@ -65,7 +65,7 @@ encoder:

# Can optimize to these metrics by choosing bin carefully
discrete_metrics:
redshift_mearn_square_error_bin:
redshift_mean_square_error_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
@@ -88,9 +88,9 @@ discrete_metrics:

# Standard metric computation
mode_sample_metrics:
redshift_mearn_square_error:
redshift_mean_square_error:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredError
redshift_mearn_square_error_bin:
redshift_mean_square_error_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
bin_type: "njymag"
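`DiscretizedFactor1D` now lives in the case study's own `variational_dist` module. A rough, heavily hedged sketch of what a discretized 1-D factor over redshift plausibly does (bin count and range are invented for illustration; the real class may differ):

```python
import torch

# Hedged sketch: a categorical distribution over redshift bins whose mode
# maps back to a point estimate via bin midpoints. Illustrative only.
n_bins, lo, hi = 30, 0.0, 3.0
edges = torch.linspace(lo, hi, n_bins + 1)
midpoints = 0.5 * (edges[:-1] + edges[1:])
logits = torch.randn(n_bins)            # one tile's raw network outputs
probs = torch.softmax(logits, dim=-1)   # categorical over bins
mode_redshift = midpoints[probs.argmax()]  # "use_mode=True"-style estimate
```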
22 changes: 10 additions & 12 deletions case_studies/redshift/redshift_from_img/encoder/encoder.py
@@ -43,19 +43,17 @@ def get_features_and_parameters(self, batch):
if isinstance(batch, dict)
else {"images": batch, "background": torch.zeros_like(batch)}
)
batch_size, _n_bands, h, w = batch["images"].shape[0:4]
ht, wt = h // self.tile_slen, w // self.tile_slen
x_features = self.get_features(batch)
batch_size, _n_features, ht, wt = x_features.shape[0:4]
pattern_to_use = (0,) # no checkerboard
mask_pattern = self.mask_patterns[pattern_to_use, ...]
est_cat = None
history_mask = mask_pattern.repeat([batch_size, ht // 2, wt // 2])
x_color_context = self.make_color_context(est_cat, history_mask)
x_features_color = torch.cat((x_features, x_color_context), dim=1)
x_cat_marginal = self.detect_first(x_features_color)

input_lst = [
inorm.get_input_tensor(batch).to(batch["images"].device)
for inorm in self.image_normalizers
]
x = torch.cat(input_lst, dim=2)
x_features = self.features_net(x)
mask = torch.zeros([batch_size, ht, wt])
context = self.make_context(None, mask).to("cuda")
x_cat_marginal = self.catalog_net(x_features, context)
return x_features, x_cat_marginal
return x_features_color, x_cat_marginal

def sample(self, batch, use_mode=True):
_, x_cat_marginal = self.get_features_and_parameters(batch)
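The rewritten `get_features_and_parameters` derives the tile grid from the feature map and tiles a fixed mask pattern (pattern 0, no checkerboard) into a full history mask. A small sketch of the repeat arithmetic (a 2x2 pattern shape is an assumption about `mask_patterns`):

```python
import torch

# Hedged sketch: tiling a 2x2 mask pattern into a full history mask,
# mirroring mask_pattern.repeat([batch_size, ht // 2, wt // 2]) above.
batch_size, ht, wt = 2, 4, 6
mask_patterns = torch.zeros(4, 2, 2)     # assumed: four 2x2 patterns
mask_pattern = mask_patterns[(0,), ...]  # pattern 0 only; shape (1, 2, 2)
history_mask = mask_pattern.repeat([batch_size, ht // 2, wt // 2])
assert history_mask.shape == (batch_size, ht, wt)
```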
50 changes: 32 additions & 18 deletions case_studies/redshift/redshift_from_img/encoder/metrics.py
@@ -10,6 +10,26 @@
from bliss.encoder.metrics import CatalogMatcher


def convert_nmgy_to_njymag(nmgy):
"""Convert from flux (nano-maggie) to mag (nano-jansky), which is the format used by DC2.
For the difference between mag (Pogson magnitude) and njymag (AB magnitude), please view
the "Flux units: maggies and nanomaggies" part of
https://www.sdss3.org/dr8/algorithms/magnitudes.php#nmgy
When we change the standard source to AB sources, we need to do the conversion
described in "2.10 AB magnitudes" at
https://pstn-001.lsst.io/fluxunits.pdf
Args:
nmgy: the fluxes in nanomaggies
Returns:
Tensor indicating fluxes in AB magnitude
"""

return 22.5 - 2.5 * torch.log10(nmgy / 3631)
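Downstream, every `update` converts the catalog's fluxes once and then buckets the true magnitudes against the configured cutoffs. A small sketch of that flow (the flux values are invented for illustration):

```python
import torch

# Hedged sketch: magnitude conversion plus torch.bucketize binning, as in
# the update() methods below; cutoffs match the configs above.
bin_cutoffs = torch.tensor([23.9, 24.1, 24.5, 24.9, 25.6])
nmgy = torch.tensor([2000.0, 900.0, 300.0])      # fluxes in nanomaggies
njymag = 22.5 - 2.5 * torch.log10(nmgy / 3631)   # same formula as above
bin_indices = torch.bucketize(njymag, bin_cutoffs)
# njymag is roughly [23.2, 24.0, 25.2] -> bin_indices == [0, 1, 4]
```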


class MetricBin(Metric):
def __init__(
self,
@@ -67,6 +87,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -80,9 +101,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes("njymag")[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

red_err = (true_red - est_red).abs() ** 2
@@ -240,6 +259,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -253,9 +273,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes(self.bin_type)[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metric_outlier = torch.abs(true_red - est_red) / (1 + true_red)
@@ -326,6 +344,7 @@ def __init__(

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -339,9 +358,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes(self.bin_type)[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metric_outlier_cata = torch.abs(true_red - est_red)
@@ -405,6 +422,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -418,9 +436,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes(self.bin_type)[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = (true_red - est_red) / (1 + true_red)
@@ -500,6 +516,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -513,9 +530,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes(self.bin_type)[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = (true_red - est_red) / (1 + true_red)
@@ -587,6 +602,7 @@ def __init__(self, **kwargs):

def update(self, true_cat, est_cat, matching):
cutoffs = torch.tensor(self.bin_cutoffs, device=self.device)
on_fluxes = convert_nmgy_to_njymag(true_cat.on_fluxes)
for i in range(true_cat.batch_size):
tcat_matches, ecat_matches = matching[i]

@@ -600,9 +616,7 @@ def update(self, true_cat, est_cat, matching):
true_red = true_cat["redshifts"][i, tcat_matches, :].to(self.device)
est_red = est_cat["redshifts"][i, ecat_matches, :].to(self.device)

true_mag = true_cat.on_fluxes(self.bin_type)[i][..., self.mag_band][tcat_matches].to(
self.device
)
true_mag = on_fluxes[i][..., self.mag_band][tcat_matches].to(self.device)
bin_indices = torch.bucketize(true_mag, cutoffs)

metrics = torch.abs(true_red - est_red) / (1 + true_red)