Reorganize and update BLISS for photo-z code (#1080)
* trying to merge into main

* linting

* linting

* linting

* linting

* linting

* linting

* linting

* passing tests

* don't modify bliss.surveys.dc2

* linting

* linting

* linting

* linting

* moving discrete metrics into RedshiftsEncoder

* linting

* don't alter bliss/encoder/variational_dist.py

* ditto

* moving all discrete 1d redshift stuff to case_studies

* linting

* minimal PR, all in case_studies. Should pass tests
declanmcnamara authored Feb 4, 2025
1 parent 73276aa commit b2b4fef
Showing 59 changed files with 13,790 additions and 26,810 deletions.
10 changes: 10 additions & 0 deletions case_studies/redshift/.gitignore
@@ -0,0 +1,10 @@
training_runs/
runs/
rail
plot
format
DC2_redshift/
DC2_redshift_predict_output/
DC2_redshift_training/
DC2_split_result_test/
DC2output
24 changes: 24 additions & 0 deletions case_studies/redshift/README.md
@@ -0,0 +1,24 @@
This redshift estimation project consists of four parts:
1. Estimate photo-z with a neural network (the training data are ground-truth magnitudes and redshifts).
2. Estimate photo-z with BLISS, directly from images.
3. Estimate photo-z with the LSST + RAIL pipeline (model from LSST).
4. Estimate photo-z with LSST + the pretrained neural network from part 1.

A few things need to be in place before you can run the evaluation (i.e., make the plots) for these four parts. I suggest looking at the fourth part first to check whether you are missing anything needed for evaluation.

1. Train the neural network
Skip this step if you already have a pretrained network. My pretrained network is saved at /data/scratch/qiaozhih/DC2_redshift_training/DC2_redshift_only_bin_allmetrics/checkpoints/encoder_0.182845.ckpt (a checkpoint-inspection sketch follows at the end of this README).
```
./preprocess_dataset.sh
./train.sh
```

2. Train BLISS
Run /home/qiaozhih/bliss/case_studies/redshift/redshift_from_img/train.sh.
You can modify the config at /home/qiaozhih/bliss/case_studies/redshift/redshift_from_img/full_train_config_redshift.yaml.

3. Train RAIL
All training code is in /home/qiaozhih/bliss/case_studies/redshift/evaluation/rail/RAIL_estimation_demo.ipynb. Make sure you have installed RAIL, and that you are using RAIL's environment rather than the BLISS one.

4. Evaluate & make plots
Run all the cells in /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb.
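
A quick sanity check on the pretrained checkpoint from step 1 — a minimal sketch, assuming the `.ckpt` file is a standard PyTorch Lightning checkpoint (a pickled dict that `torch.load` can read); nothing here is required by the pipeline:

```python
# Minimal sketch: peek inside the pretrained encoder checkpoint from step 1.
# Assumption: it is a standard PyTorch Lightning checkpoint (a pickled dict).
import torch

ckpt_path = (
    "/data/scratch/qiaozhih/DC2_redshift_training/"
    "DC2_redshift_only_bin_allmetrics/checkpoints/encoder_0.182845.ckpt"
)
# weights_only=False because Lightning checkpoints also store non-tensor metadata
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(list(ckpt.keys()))        # typically includes "state_dict", "hyper_parameters", ...
print(len(ckpt["state_dict"]))  # number of saved parameter tensors
```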
87 changes: 87 additions & 0 deletions case_studies/redshift/artifacts/data_generation.py
@@ -0,0 +1,87 @@
import logging
from pathlib import Path

import GCRCatalogs
import hydra
import numpy as np
import pandas as pd
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from case_studies.redshift.artifacts.redshift_dc2 import DC2DataModule

logging.basicConfig(level=logging.INFO)


# ------------- RAIL ------------ #
def create_rail_artifacts(rail_cfg: DictConfig):
    """Create DataFrames of ugrizy magnitudes and errors for RAIL training."""
    logging.info("Creating RAIL artifacts at %s", rail_cfg.processed_data_dir)
    log_dir = Path(rail_cfg.processed_data_dir)

    # Create output directory if it does not exist, or skip if artifacts already exist
    if not log_dir.exists():
        log_dir.mkdir(parents=True, exist_ok=True)
    elif not rail_cfg.pipeline.force_reprocess:
        logging.info("RAIL artifacts already exist. Skipping creation.")
        return

    lsst_root_dir = rail_cfg.pipeline.lsst_root_dir
    GCRCatalogs.set_root_dir(lsst_root_dir)
    lsst_catalog_gcr = GCRCatalogs.load_catalog(rail_cfg.pipeline.truth_match_catalog)
    lsst_catalog_subset = lsst_catalog_gcr.get_quantities(list(rail_cfg.pipeline.quantities))
    lsst_catalog_df = pd.DataFrame(lsst_catalog_subset)

    # Drop rows with inf or NaN
    lsst_catalog_df_na = lsst_catalog_df.replace([np.inf, -np.inf], np.nan)
    lsst_catalog_df_nona = lsst_catalog_df_na.dropna()

    # Rename some columns
    new_name = {
        "id_truth": "id",
        "mag_u_cModel": "mag_u_lsst",
        "mag_g_cModel": "mag_g_lsst",
        "mag_r_cModel": "mag_r_lsst",
        "mag_i_cModel": "mag_i_lsst",
        "mag_z_cModel": "mag_z_lsst",
        "mag_y_cModel": "mag_y_lsst",
        "magerr_u_cModel": "mag_err_u_lsst",
        "magerr_g_cModel": "mag_err_g_lsst",
        "magerr_r_cModel": "mag_err_r_lsst",
        "magerr_i_cModel": "mag_err_i_lsst",
        "magerr_z_cModel": "mag_err_z_lsst",
        "magerr_y_cModel": "mag_err_y_lsst",
    }

    lsst_catalog_df_nona_newname = lsst_catalog_df_nona.rename(new_name, axis=1)

    # Save pickle for RAIL-based training.
    train_nrow = rail_cfg.pipeline.train_size
    val_nrow = rail_cfg.pipeline.val_size
    lsst_catalog_df_nona_newname[:train_nrow].to_pickle(log_dir / "lsst_train_nona_200k.pkl")
    lsst_catalog_df_nona_newname[-1 - val_nrow - 1 :].to_pickle(log_dir / "lsst_val_nona_100k.pkl")


# ------------- BLISS ----------- #
def create_bliss_artifacts(bliss_cfg: DictConfig):
    """CONSTRUCT BATCHES (.pt files) FOR DATA LOADING."""
    logging.info("Creating BLISS artifacts at %s", bliss_cfg.paths.processed_data_dir_bliss)
    dc2: DC2DataModule = instantiate(bliss_cfg.surveys.dc2)
    dc2.prepare_data()


@hydra.main(config_path="../", config_name="redshift")
def main(cfg: DictConfig) -> None:
    logging.info("Starting data generation")
    logging.info(OmegaConf.to_yaml(cfg))
    # Create RAIL artifacts
    create_rail_artifacts(cfg.rail)

    # Create BLISS artifacts
    create_bliss_artifacts(cfg)

    logging.info("Data generation complete")


if __name__ == "__main__":
    main()
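
As a reading aid for the hunk above (not part of this commit), here is a small, self-contained pandas sketch of what `create_rail_artifacts` does to the DC2 truth-match catalog — replace ±inf with NaN, drop incomplete rows, rename columns to RAIL's `*_lsst` convention, and slice off train/val subsets — using toy data in place of the GCR catalog:

```python
import numpy as np
import pandas as pd

# Toy stand-in for the GCR truth-match catalog (only one band shown).
df = pd.DataFrame(
    {
        "id_truth": [1, 2, 3, 4],
        "mag_u_cModel": [22.1, np.inf, 23.4, 21.9],
        "magerr_u_cModel": [0.05, 0.02, np.nan, 0.03],
        "redshift": [0.4, 0.7, 1.1, 0.2],
    }
)

# Same cleaning steps as in create_rail_artifacts.
df = df.replace([np.inf, -np.inf], np.nan).dropna()
df = df.rename(
    {"id_truth": "id", "mag_u_cModel": "mag_u_lsst", "magerr_u_cModel": "mag_err_u_lsst"},
    axis=1,
)

# Stand-ins for rail_cfg.pipeline.train_size / val_size.
train_nrow, val_nrow = 1, 1
train_df = df[:train_nrow]  # head of the table for training
val_df = df[-val_nrow:]     # tail for validation (the commit keeps a slightly wider tail slice)
print(train_df.columns.tolist(), len(train_df), len(val_df))
```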
156 changes: 156 additions & 0 deletions case_studies/redshift/artifacts/redshift_dc2.py
@@ -0,0 +1,156 @@
# pylint: disable=R0801
import logging
import multiprocessing
import pathlib
from pathlib import Path

import torch

from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, split_tensor, unpack_dict


class RedshiftDC2DataModule(DC2DataModule):
    BANDS = ("u", "g", "r", "i", "z", "y")

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.dc2_image_dir = Path(self.dc2_image_dir)
        self.dc2_cat_path = Path(self.dc2_cat_path)
        self._tract_patches = None

    def _load_image_and_bg_files_list(self):
        img_pattern = "**/*/calexp*.fits"
        bg_pattern = "**/*/bkgd*.fits"
        image_files = []
        bg_files = []

        for band in self.bands:
            band_path = self.dc2_image_dir / str(band)
            img_file_list = list(pathlib.Path(band_path).glob(img_pattern))
            bg_file_list = list(pathlib.Path(band_path).glob(bg_pattern))

            image_files.append(sorted(img_file_list))
            bg_files.append(sorted(bg_file_list))
        n_image = len(bg_files[0])

        # assign state only in main process
        self._image_files = image_files
        self._bg_files = bg_files

        # record which tracts and patches
        tracts = [str(file_name).split("/")[-3] for file_name in self._image_files[0]]
        patches = [
            str(file_name).rsplit("-", maxsplit=1)[-1][:3] for file_name in self._image_files[0]
        ]  # TODO: check
        self._tract_patches = [x[0] + "_" + x[1] for x in zip(tracts, patches)]  # TODO: hack

        return n_image

    def prepare_data(self):  # noqa: WPS324
        if self.cached_data_path.exists():
            logger = logging.getLogger("DC2DataModule")
            warning_msg = "WARNING: cached data already exists at [%s], we directly use it\n"
            logger.warning(warning_msg, str(self.cached_data_path))
            return None

        logger = logging.getLogger("DC2DataModule")
        warning_msg = "WARNING: can't find cached data, we generate it at [%s]\n"
        logger.warning(warning_msg, str(self.cached_data_path))
        if not self.cached_data_path.exists():
            self.cached_data_path.mkdir(parents=True)

        n_image = self._load_image_and_bg_files_list()

        # Train
        if self.prepare_data_processes_num > 1:
            with multiprocessing.Pool(processes=self.prepare_data_processes_num) as process_pool:
                process_pool.map(
                    self.generate_cached_data,
                    zip(list(range(n_image)), self._tract_patches),
                    chunksize=4,
                )
        else:
            for i in range(n_image):
                self.generate_cached_data((i, self._tract_patches[i]))

        return None

    def generate_cached_data(self, naming_info: tuple):  # pylint: disable=W0237,R0801
        image_index, patch_name = naming_info
        result_dict = self.load_image_and_catalog(image_index)

        image = result_dict["inputs"]["image"]
        tile_dict = result_dict["tile_dict"]
        wcs_header_str = result_dict["other_info"]["wcs_header_str"]
        psf_params = result_dict["inputs"]["psf_params"]

        # split image
        split_lim = self.image_lim[0] // self.n_image_split
        image_splits = split_tensor(image, split_lim, 1, 2)
        image_width_pixels = image.shape[2]
        split_image_num_on_width = image_width_pixels // split_lim

        # split tile cat
        tile_cat_splits = {}
        param_list = [
            "locs",
            "n_sources",
            "source_type",
            "galaxy_fluxes",
            "star_fluxes",
            "redshifts",
            "blendedness",
            "shear",
            "ellipticity",
            "cosmodc2_mask",
            "one_source_mask",
            "two_sources_mask",
            "more_than_two_sources_mask",
        ]
        for param_name in param_list:
            tile_cat_splits[param_name] = split_tensor(
                tile_dict[param_name], split_lim // self.tile_slen, 0, 1
            )

        objid = split_tensor(tile_dict["objid"], split_lim // self.tile_slen, 0, 1)

        data_splits = {
            "tile_catalog": unpack_dict(tile_cat_splits),
            "images": image_splits,
            "image_height_index": (
                torch.arange(0, len(image_splits)) // split_image_num_on_width
            ).tolist(),
            "image_width_index": (
                torch.arange(0, len(image_splits)) % split_image_num_on_width
            ).tolist(),
            "psf_params": [psf_params for _ in range(self.n_image_split**2)],
            "objid": objid,
        }
        data_splits = split_list(
            unpack_dict(data_splits),
            sub_list_len=self.data_in_one_cached_file,
        )

        data_count = 0
        for sub_splits in data_splits:  # noqa: WPS426
            tmp_data_cached = []
            for split in sub_splits:  # noqa: WPS426
                split_clone = map_nested_dicts(
                    split, lambda x: x.clone() if isinstance(x, torch.Tensor) else x
                )
                split_clone.update(wcs_header_str=wcs_header_str)
                tmp_data_cached.append(split_clone)
            assert data_count < 1e5 and image_index < 1e5, "too many cached data files"
            assert len(tmp_data_cached) < 1e5, "too many cached data in one file"
            cached_data_file_name = (
                f"cached_data_{patch_name}_{data_count:04d}_size_{len(tmp_data_cached):04d}.pt"
            )
            cached_data_file_path = self.cached_data_path / cached_data_file_name
            with open(cached_data_file_path, "wb") as cached_data_file:
                torch.save(tmp_data_cached, cached_data_file)
            data_count += 1
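
One non-obvious piece of bookkeeping in `generate_cached_data` above is how each cached sub-image remembers its position in the original image. A toy check (not part of this commit), assuming `split_tensor` enumerates sub-images row-major as the arithmetic implies:

```python
import torch

# Suppose an 8x8-pixel image is cut into 4x4 sub-images: 2 splits per row, 4 splits total.
image_width_pixels = 8
split_lim = 4
split_image_num_on_width = image_width_pixels // split_lim  # 2
n_splits = 4

# Row-major enumeration: integer division gives the row, modulo gives the column.
height_index = (torch.arange(0, n_splits) // split_image_num_on_width).tolist()
width_index = (torch.arange(0, n_splits) % split_image_num_on_width).tolist()

print(height_index)  # [0, 0, 1, 1]
print(width_index)   # [0, 1, 0, 1]
```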
15 changes: 15 additions & 0 deletions case_studies/redshift/evaluation/continuous_eval.yaml
@@ -0,0 +1,15 @@
---
defaults:
  - ../redshift_from_img@_here_: continuous
  - _self_
  # - override hydra/job_logging: stdout

paths:
  ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_cts_results/checkpoints
  plot_dir: ${paths.data_dir}/plots

# To reduce memory usage
surveys:
  dc2:
    batch_size: 4
    max_sources_per_tile: 5
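
A sketch (not part of this commit) of composing this evaluation config from Python with Hydra's compose API; the working directory and the availability of the referenced `redshift_from_img` config group are assumptions based on the file layout above:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Assumes this runs from case_studies/redshift/evaluation, where continuous_eval.yaml lives
# and the ../redshift_from_img defaults it references can be resolved.
with initialize(version_base=None, config_path="."):
    cfg = compose(
        config_name="continuous_eval",
        overrides=["surveys.dc2.batch_size=2"],  # shrink batches further if memory is tight
    )

print(OmegaConf.to_yaml(cfg.surveys.dc2))
```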