Reorganize and update BLISS for photo-z code (#1080)
* trying to merge into main * linting * linting * linting * linting * linting * linting * linting * passing tests * don't modify bliss.surveys.dc2 * linting * linting * linting * linting * moving discrete metrics into RedshiftsEncoder * linting * don't alter bliss/encoder/variational_dist.py * ditto * moving all discrete 1d redshift stuff to case_studies * linting * minimal PR, all in case_studies. Should pass tests
1 parent 73276aa, commit b2b4fef. 59 changed files with 13,790 additions and 26,810 deletions.
@@ -0,0 +1,10 @@
training_runs/
runs/
rail
plot
format
DC2_redshift/
DC2_redshift_predict_output/
DC2_redshift_training/
DC2_split_result_test/
DC2output
@@ -0,0 +1,24 @@
This redshift estimation project consists of 4 parts:
1. Estimate photo-z using a neural network (training data is ground-truth magnitudes and redshifts).
2. Estimate photo-z using BLISS directly from images.
3. Estimate photo-z using the LSST + RAIL pipeline (model from LSST).
4. Estimate photo-z using LSST + the pretrained neural network from part 1.

There are a few things to do before you can run the evaluation (make plots) for these four parts. I suggest you first look at the fourth part to check whether you are missing any key pieces for evaluation.

1. Training the neural network
Skip this step if you already have a pretrained network. My pretrained network is saved at /data/scratch/qiaozhih/DC2_redshift_training/DC2_redshift_only_bin_allmetrics/checkpoints/encoder_0.182845.ckpt (a sketch for inspecting it follows the commands below).
```
./preprocess_dataset.sh
./train.sh
```
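If you want to sanity-check the pretrained checkpoint before evaluating, here is a minimal sketch, assuming a standard PyTorch Lightning-style `.ckpt` file (the exact contents depend on the encoder config):

```python
# Minimal sketch: inspect the pretrained checkpoint listed above.
# Assumes a Lightning-style .ckpt file; keys may differ for your setup.
import torch

ckpt_path = (
    "/data/scratch/qiaozhih/DC2_redshift_training/"
    "DC2_redshift_only_bin_allmetrics/checkpoints/encoder_0.182845.ckpt"
)
ckpt = torch.load(ckpt_path, map_location="cpu")
print(ckpt.keys())              # typically "state_dict", "hyper_parameters", "epoch", ...
print(len(ckpt["state_dict"]))  # number of saved parameter tensors
```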

2. Train BLISS
Run /home/qiaozhih/bliss/case_studies/redshift/redshift_from_img/train.sh
You can modify the config at /home/qiaozhih/bliss/case_studies/redshift/redshift_from_img/full_train_config_redshift.yaml (see the sketch below for one way to inspect it before training).
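A minimal sketch for inspecting that config before launching training, assuming it is a standalone Hydra config in the directory above (the directory and config name are taken from the paths mentioned here, not verified):

```python
# Minimal sketch: load full_train_config_redshift.yaml with Hydra's compose API
# to check fields (e.g. batch size) before training. Paths are assumptions
# based on the locations mentioned above.
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

config_dir = "/home/qiaozhih/bliss/case_studies/redshift/redshift_from_img"
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name="full_train_config_redshift")

print(OmegaConf.to_yaml(cfg.surveys.dc2))  # e.g. check the dc2 batch_size and paths
```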

3. Train RAIL
All training code can be found at /home/qiaozhih/bliss/case_studies/redshift/evaluation/rail/RAIL_estimation_demo.ipynb. Make sure you have installed RAIL, and make sure you are using the corresponding RAIL environment instead of the BLISS one.

4. Evaluate & make plots
Run all the code in /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb
@@ -0,0 +1,87 @@
import logging
from pathlib import Path

import GCRCatalogs
import hydra
import numpy as np
import pandas as pd
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from case_studies.redshift.artifacts.redshift_dc2 import DC2DataModule

logging.basicConfig(level=logging.INFO)


# ------------- RAIL ------------ #
def create_rail_artifacts(rail_cfg: DictConfig):
    """Create DataFrames of ugrizy magnitudes and errors for RAIL training."""
    logging.info("Creating RAIL artifacts at %s", rail_cfg.processed_data_dir)
    log_dir = Path(rail_cfg.processed_data_dir)

    # Create output directory if it does not exist, or skip if artifacts already exist
    if not log_dir.exists():
        log_dir.mkdir(parents=True, exist_ok=True)
    elif not rail_cfg.pipeline.force_reprocess:
        logging.info("RAIL artifacts already exist. Skipping creation.")
        return

    lsst_root_dir = rail_cfg.pipeline.lsst_root_dir
    GCRCatalogs.set_root_dir(lsst_root_dir)
    lsst_catalog_gcr = GCRCatalogs.load_catalog(rail_cfg.pipeline.truth_match_catalog)
    lsst_catalog_subset = lsst_catalog_gcr.get_quantities(list(rail_cfg.pipeline.quantities))
    lsst_catalog_df = pd.DataFrame(lsst_catalog_subset)

    # Drop rows with inf or NaN
    lsst_catalog_df_na = lsst_catalog_df.replace([np.inf, -np.inf], np.nan)
    lsst_catalog_df_nona = lsst_catalog_df_na.dropna()

    # Rename some columns
    new_name = {
        "id_truth": "id",
        "mag_u_cModel": "mag_u_lsst",
        "mag_g_cModel": "mag_g_lsst",
        "mag_r_cModel": "mag_r_lsst",
        "mag_i_cModel": "mag_i_lsst",
        "mag_z_cModel": "mag_z_lsst",
        "mag_y_cModel": "mag_y_lsst",
        "magerr_u_cModel": "mag_err_u_lsst",
        "magerr_g_cModel": "mag_err_g_lsst",
        "magerr_r_cModel": "mag_err_r_lsst",
        "magerr_i_cModel": "mag_err_i_lsst",
        "magerr_z_cModel": "mag_err_z_lsst",
        "magerr_y_cModel": "mag_err_y_lsst",
    }

    lsst_catalog_df_nona_newname = lsst_catalog_df_nona.rename(new_name, axis=1)

    # Save pickle for RAIL-based training.
    train_nrow = rail_cfg.pipeline.train_size
    val_nrow = rail_cfg.pipeline.val_size
    lsst_catalog_df_nona_newname[:train_nrow].to_pickle(log_dir / "lsst_train_nona_200k.pkl")
    lsst_catalog_df_nona_newname[-1 - val_nrow - 1 :].to_pickle(log_dir / "lsst_val_nona_100k.pkl")


# ------------- BLISS ----------- #
def create_bliss_artifacts(bliss_cfg: DictConfig):
    """CONSTRUCT BATCHES (.pt files) FOR DATA LOADING."""
    logging.info("Creating BLISS artifacts at %s", bliss_cfg.paths.processed_data_dir_bliss)
    dc2: DC2DataModule = instantiate(bliss_cfg.surveys.dc2)
    dc2.prepare_data()


@hydra.main(config_path="../", config_name="redshift")
def main(cfg: DictConfig) -> None:
    logging.info("Starting data generation")
    logging.info(OmegaConf.to_yaml(cfg))
    # Create RAIL artifacts
    create_rail_artifacts(cfg.rail)

    # Create BLISS artifacts
    create_bliss_artifacts(cfg)

    logging.info("Data generation complete")


if __name__ == "__main__":
    main()
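For reference, a minimal sketch of how these artifact builders could be driven outside the @hydra.main entry point (for example, from a notebook). The import path case_studies.redshift.artifacts.generate_artifacts is an assumption for illustration, not a module name confirmed by this diff:

```python
# Minimal sketch: compose the "redshift" config and call the artifact builders
# directly. The import path below is hypothetical; adjust to the actual module.
from hydra import compose, initialize

from case_studies.redshift.artifacts.generate_artifacts import (
    create_bliss_artifacts,
    create_rail_artifacts,
)

with initialize(config_path="../", version_base=None):
    cfg = compose(config_name="redshift")

create_rail_artifacts(cfg.rail)  # writes lsst_train_nona_200k.pkl / lsst_val_nona_100k.pkl
create_bliss_artifacts(cfg)      # caches .pt batches via DC2DataModule.prepare_data()
```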
@@ -0,0 +1,156 @@
# pylint: disable=R0801
import logging
import multiprocessing
import pathlib
from pathlib import Path

import torch

from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, split_tensor, unpack_dict


class RedshiftDC2DataModule(DC2DataModule):
    BANDS = ("u", "g", "r", "i", "z", "y")

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.dc2_image_dir = Path(self.dc2_image_dir)
        self.dc2_cat_path = Path(self.dc2_cat_path)
        self._tract_patches = None

    def _load_image_and_bg_files_list(self):
        img_pattern = "**/*/calexp*.fits"
        bg_pattern = "**/*/bkgd*.fits"
        image_files = []
        bg_files = []

        for band in self.bands:
            band_path = self.dc2_image_dir / str(band)
            img_file_list = list(pathlib.Path(band_path).glob(img_pattern))
            bg_file_list = list(pathlib.Path(band_path).glob(bg_pattern))

            image_files.append(sorted(img_file_list))
            bg_files.append(sorted(bg_file_list))
        n_image = len(bg_files[0])

        # assign state only in main process
        self._image_files = image_files
        self._bg_files = bg_files

        # record which tracts and patches
        tracts = [str(file_name).split("/")[-3] for file_name in self._image_files[0]]
        patches = [
            str(file_name).rsplit("-", maxsplit=1)[-1][:3] for file_name in self._image_files[0]
        ]  # TODO: check
        self._tract_patches = [x[0] + "_" + x[1] for x in zip(tracts, patches)]  # TODO: hack

        return n_image

    def prepare_data(self):  # noqa: WPS324
        if self.cached_data_path.exists():
            logger = logging.getLogger("DC2DataModule")
            warning_msg = "WARNING: cached data already exists at [%s], we directly use it\n"
            logger.warning(warning_msg, str(self.cached_data_path))
            return None

        logger = logging.getLogger("DC2DataModule")
        warning_msg = "WARNING: can't find cached data, we generate it at [%s]\n"
        logger.warning(warning_msg, str(self.cached_data_path))
        if not self.cached_data_path.exists():
            self.cached_data_path.mkdir(parents=True)

        n_image = self._load_image_and_bg_files_list()

        # Train
        if self.prepare_data_processes_num > 1:
            with multiprocessing.Pool(processes=self.prepare_data_processes_num) as process_pool:
                process_pool.map(
                    self.generate_cached_data,
                    zip(list(range(n_image)), self._tract_patches),
                    chunksize=4,
                )
        else:
            for i in range(n_image):
                self.generate_cached_data((i, self._tract_patches[i]))

        return None

    def generate_cached_data(self, naming_info: tuple):  # pylint: disable=W0237,R0801
        image_index, patch_name = naming_info
        result_dict = self.load_image_and_catalog(image_index)

        image = result_dict["inputs"]["image"]
        tile_dict = result_dict["tile_dict"]
        wcs_header_str = result_dict["other_info"]["wcs_header_str"]
        psf_params = result_dict["inputs"]["psf_params"]

        # split image
        split_lim = self.image_lim[0] // self.n_image_split
        image_splits = split_tensor(image, split_lim, 1, 2)
        image_width_pixels = image.shape[2]
        split_image_num_on_width = image_width_pixels // split_lim

        # split tile cat
        tile_cat_splits = {}
        param_list = [
            "locs",
            "n_sources",
            "source_type",
            "galaxy_fluxes",
            "star_fluxes",
            "redshifts",
            "blendedness",
            "shear",
            "ellipticity",
            "cosmodc2_mask",
            "one_source_mask",
            "two_sources_mask",
            "more_than_two_sources_mask",
        ]
        for param_name in param_list:
            tile_cat_splits[param_name] = split_tensor(
                tile_dict[param_name], split_lim // self.tile_slen, 0, 1
            )

        objid = split_tensor(tile_dict["objid"], split_lim // self.tile_slen, 0, 1)

        data_splits = {
            "tile_catalog": unpack_dict(tile_cat_splits),
            "images": image_splits,
            "image_height_index": (
                torch.arange(0, len(image_splits)) // split_image_num_on_width
            ).tolist(),
            "image_width_index": (
                torch.arange(0, len(image_splits)) % split_image_num_on_width
            ).tolist(),
            "psf_params": [psf_params for _ in range(self.n_image_split**2)],
            "objid": objid,
        }
        data_splits = split_list(
            unpack_dict(data_splits),
            sub_list_len=self.data_in_one_cached_file,
        )

        data_count = 0
        for sub_splits in data_splits:  # noqa: WPS426
            tmp_data_cached = []
            for split in sub_splits:  # noqa: WPS426
                split_clone = map_nested_dicts(
                    split, lambda x: x.clone() if isinstance(x, torch.Tensor) else x
                )
                split_clone.update(wcs_header_str=wcs_header_str)
                tmp_data_cached.append(split_clone)
            assert data_count < 1e5 and image_index < 1e5, "too many cached data files"
            assert len(tmp_data_cached) < 1e5, "too many cached data in one file"
            cached_data_file_name = (
                f"cached_data_{patch_name}_{data_count:04d}_size_{len(tmp_data_cached):04d}.pt"
            )
            cached_data_file_path = self.cached_data_path / cached_data_file_name
            with open(cached_data_file_path, "wb") as cached_data_file:
                torch.save(tmp_data_cached, cached_data_file)
            data_count += 1
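A minimal sketch of what generate_cached_data writes, which can help when debugging the data loader; the file name below is illustrative (real names follow the cached_data_<patch>_<index>_size_<n>.pt pattern built above):

```python
# Minimal sketch: inspect one cached batch file produced above.
# Each .pt file stores a list of dicts assembled in data_splits.
import torch

cached = torch.load("cached_data_3828_1,1_0000_size_0050.pt")  # illustrative name
first = cached[0]
print(sorted(first.keys()))                  # images, tile_catalog, psf_params, objid, ...
print(first["images"].shape)                 # one image crop of side split_lim pixels
print(sorted(first["tile_catalog"].keys()))  # locs, n_sources, redshifts, ...
```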
@@ -0,0 +1,15 @@
---
defaults:
  - ../redshift_from_img@_here_: continuous
  - _self_
  # - override hydra/job_logging: stdout

paths:
  ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_cts_results/checkpoints
  plot_dir: ${paths.data_dir}/plots

# To reduce memory usage
surveys:
  dc2:
    batch_size: 4
    max_sources_per_tile: 5
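For context, a minimal sketch of checking how these overrides land after Hydra composes this file with the redshift_from_img "continuous" defaults; the config directory and file name below are placeholders, since they are not visible in this diff:

```python
# Minimal sketch: compose this config and confirm the memory-saving overrides.
# "<this_config>" and the directory path are placeholders for the actual names.
from hydra import compose, initialize_config_dir

with initialize_config_dir(config_dir="/path/to/this/config/dir", version_base=None):
    cfg = compose(config_name="<this_config>")

assert cfg.surveys.dc2.batch_size == 4
assert cfg.surveys.dc2.max_sources_per_tile == 5
print(cfg.paths.ckpt_dir)  # .../BLISS_DC2_redshift_cts_results/checkpoints
```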