
Commit 1c9384d

Final tweaks to get redshift code up to date with master. (#1081)
* tweaks
* made some changes to fit with master; now need to update RedshiftsEncoder
* getting the correct shapes from sampling for RedshiftsEncoder. Samples are coming through. About to change the method to remove the `self.bin_type` indexing. Previously `on_fluxes` must have been non-tensor
* at this commit, things were running. Still want to run all the way through to confirm the plots still look good
* linting
* restore other case studies
* something is up with the binning; checking it out inside the metric classes
* linting
1 parent b2b4fef commit 1c9384d

File tree

14 files changed: +108 -102 lines changed


case_studies/redshift/README.md

Lines changed: 22 additions & 2 deletions
````diff
@@ -1,4 +1,24 @@
-This redshift estimation project is consist of 4 parts:
+# BLISS-PZ - BLISS For Photo-z Prediction
+
+#### Running BLISS-PZ on DC2
+
+Modify the config file `redshift.yaml` as follows:
+1. Change `paths.data_dir` to a directory where you're happy to have all data artifacts and checkpoints stored.
+2. Make the `OUT_DIR` variable in `runner.sh` this same location.
+3. Modify `paths.dc2` to the location of `dc2` on your system.
+
+To produce the results from `BLISS-PZ`, run `runner.sh` (you may need to make it executable: `chmod +x runner.sh` from within this directory).
+
+```
+./runner.sh
+```
+
+The runner bash script launches programs sequentially: first data prep, then two different runs of BLISS, followed by RAIL. Thereafter, plots are produced. For your use case it may be better to run different parts of the runner script on their own; take a look at the script and comment out the relevant parts as needed.
+
+
+
+
+<!-- This redshift estimation project is consist of 4 parts:
 1. Estimate photo-z using neural network (training data is GT mag and redshift)
 2. Estimate photo-z using bliss directly from image.
 3. Estimate photo-z using lsst + rail pipeline (model from LSST)
@@ -21,4 +41,4 @@ You can modify config at /home/qiaozhih/bliss/case_studies/redshift/redshift_fro
 All training code can be found at /home/qiaozhih/bliss/case_studies/redshift/evaluation/rail/RAIL_estimation_demo.ipynb. Make sure you install rail from and you must make sure you are using the corresponding env from rail instead of the bliss.
 
 4. Evaluate & make plot
-Run all the code at /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb
+Run all the code at /home/qiaozhih/bliss/case_studies/redshift/evaluation/dc2_plot.ipynb -->
````
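The sequential launch the README describes can be sketched in Python. The stage commands below are placeholders, not the actual contents of `runner.sh`; the runner callable is injectable so the control flow can be exercised without the real scripts.

```python
import subprocess

def run_pipeline(stages, runner=subprocess.run):
    """Launch each stage in order; check=True aborts the pipeline on failure."""
    for cmd in stages:
        runner(cmd, check=True)

# Hypothetical stage list mirroring the order described above:
# data prep, two BLISS runs, RAIL, then plotting.
STAGES = [
    ["python", "prep_data.py"],
    ["python", "train_bliss.py", "--config-name=continuous"],
    ["python", "train_bliss.py", "--config-name=discrete"],
    ["python", "run_rail.py"],
    ["python", "make_plots.py"],
]
```

Commenting out entries of the stage list corresponds to commenting out parts of `runner.sh`.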

case_studies/redshift/artifacts/data_generation.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -8,7 +8,7 @@
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
 
-from case_studies.redshift.artifacts.redshift_dc2 import DC2DataModule
+from case_studies.redshift.artifacts.redshift_dc2 import RedshiftDC2DataModule
 
 logging.basicConfig(level=logging.INFO)
 
@@ -66,7 +66,7 @@ def create_rail_artifacts(rail_cfg: DictConfig):
 def create_bliss_artifacts(bliss_cfg: DictConfig):
     """CONSTRUCT BATCHES (.pt files) FOR DATA LOADING."""
     logging.info("Creating BLISS artifacts at %s", bliss_cfg.paths.processed_data_dir_bliss)
-    dc2: DC2DataModule = instantiate(bliss_cfg.surveys.dc2)
+    dc2: RedshiftDC2DataModule = instantiate(bliss_cfg.surveys.dc2)
     dc2.prepare_data()
```

case_studies/redshift/artifacts/redshift_dc2.py

Lines changed: 5 additions & 30 deletions
```diff
@@ -6,7 +6,7 @@
 
 import torch
 
-from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, split_tensor, unpack_dict
+from bliss.surveys.dc2 import DC2DataModule, map_nested_dicts, split_list, unpack_dict
 
 
 class RedshiftDC2DataModule(DC2DataModule):
@@ -89,20 +89,11 @@ def generate_cached_data(self, naming_info: tuple):  # pylint: disable=W0237,R08
         wcs_header_str = result_dict["other_info"]["wcs_header_str"]
         psf_params = result_dict["inputs"]["psf_params"]
 
-        # split image
-        split_lim = self.image_lim[0] // self.n_image_split
-        image_splits = split_tensor(image, split_lim, 1, 2)
-        image_width_pixels = image.shape[2]
-        split_image_num_on_width = image_width_pixels // split_lim
-
-        # split tile cat
-        tile_cat_splits = {}
         param_list = [
             "locs",
             "n_sources",
             "source_type",
-            "galaxy_fluxes",
-            "star_fluxes",
+            "fluxes",
             "redshifts",
             "blendedness",
             "shear",
@@ -112,27 +103,11 @@ def generate_cached_data(self, naming_info: tuple):  # pylint: disable=W0237,R08
             "two_sources_mask",
             "more_than_two_sources_mask",
         ]
-        for param_name in param_list:
-            tile_cat_splits[param_name] = split_tensor(
-                tile_dict[param_name], split_lim // self.tile_slen, 0, 1
-            )
 
-        objid = split_tensor(tile_dict["objid"], split_lim // self.tile_slen, 0, 1)
-
-        data_splits = {
-            "tile_catalog": unpack_dict(tile_cat_splits),
-            "images": image_splits,
-            "image_height_index": (
-                torch.arange(0, len(image_splits)) // split_image_num_on_width
-            ).tolist(),
-            "image_width_index": (
-                torch.arange(0, len(image_splits)) % split_image_num_on_width
-            ).tolist(),
-            "psf_params": [psf_params for _ in range(self.n_image_split**2)],
-            "objid": objid,
-        }
+        splits = self.split_image_and_tile_cat(image, tile_dict, param_list, psf_params)
+
         data_splits = split_list(
-            unpack_dict(data_splits),
+            unpack_dict(splits),
             sub_list_len=self.data_in_one_cached_file,
         )
```
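For reference, here is a self-contained sketch of what the new `split_image_and_tile_cat` helper plausibly does, reconstructed from the inline logic this commit removes. `split_tensor` and `unpack_dict` are re-implemented below for self-containment (in BLISS they come from `bliss.surveys.dc2`), and the default sizes are illustrative assumptions, not the project's configuration.

```python
import torch

def split_tensor(t, split_size, dim1, dim2):
    # Split `t` into square patches of side `split_size` along two dims,
    # returned as a flat row-major list.
    patches = []
    for row in t.split(split_size, dim=dim1):
        patches.extend(row.split(split_size, dim=dim2))
    return patches

def unpack_dict(d):
    # Turn a dict of equal-length lists into a list of dicts.
    return [dict(zip(d.keys(), vals)) for vals in zip(*d.values())]

def split_image_and_tile_cat(image, tile_dict, param_list, psf_params,
                             image_lim=256, n_image_split=2, tile_slen=4):
    # Split the full image into n_image_split**2 sub-images and slice the
    # tile catalog to match, recording each sub-image's grid position.
    split_lim = image_lim // n_image_split  # sub-image side length in pixels
    image_splits = split_tensor(image, split_lim, 1, 2)
    num_on_width = image.shape[2] // split_lim

    tile_cat_splits = {
        p: split_tensor(tile_dict[p], split_lim // tile_slen, 0, 1)
        for p in param_list
    }
    return {
        "tile_catalog": unpack_dict(tile_cat_splits),
        "images": image_splits,
        "image_height_index": (torch.arange(len(image_splits)) // num_on_width).tolist(),
        "image_width_index": (torch.arange(len(image_splits)) % num_on_width).tolist(),
        "psf_params": [psf_params for _ in range(n_image_split**2)],
        "objid": split_tensor(tile_dict["objid"], split_lim // tile_slen, 0, 1),
    }
```

Hoisting this into a method lets the redshift subclass reuse the base module's splitting while only overriding `param_list` (here, `fluxes` replacing `galaxy_fluxes`/`star_fluxes`).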

case_studies/redshift/evaluation/continuous_eval.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@ defaults:
 # - override hydra/job_logging: stdout
 
 paths:
-  ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_cts_results/checkpoints
+  ckpt_dir: ${paths.data_dir}/checkpoints/continuous/checkpoints
   plot_dir: ${paths.data_dir}/plots
 
 # To reduce memory usage
```

case_studies/redshift/evaluation/discrete_eval.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -5,7 +5,7 @@ defaults:
 # - override hydra/job_logging: stdout
 
 paths:
-  ckpt_dir: /data/scratch/declan/redshift/dc2/BLISS_DC2_redshift_discrete_results/checkpoints
+  ckpt_dir: ${paths.data_dir}/checkpoints/discrete/checkpoints
   plot_dir: ${paths.data_dir}/plots
 
 # To reduce memory usage
```

case_studies/redshift/evaluation/evaluate_cts.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -28,7 +28,7 @@ def main(cfg: DictConfig):
     output_dir.mkdir(parents=True, exist_ok=True)
 
     ckpt_path = get_best_ckpt(ckpt_dir)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cpu")
 
     # set up testing dataset
     dataset = instantiate(cfg.train.data_source)
@@ -47,7 +47,6 @@ def main(cfg: DictConfig):
     if not bliss_output_path.exists():
         test_loader = dataset.test_dataloader()
         for batch_idx, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
-            batch["images"] = batch["images"].to(device)
             bliss_encoder.update_metrics(batch, batch_idx)
         bliss_out_dict = bliss_encoder.mode_metrics.compute()
```

case_studies/redshift/evaluation/evaluate_discrete.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ def main(cfg: DictConfig):
     output_dir.mkdir(parents=True, exist_ok=True)
 
     ckpt_path = get_best_ckpt(ckpt_dir)
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cpu")
 
     # set up testing dataset
     dataset = instantiate(cfg.train.data_source)
```

case_studies/redshift/redshift.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ paths:
 # Defaults from `base_config`. We modify cached path and use own class
 surveys:
   dc2:
-    _target_: bliss.case_studies.redshift.artifacts.redshift_dc2.RedshiftDC2DataModule
+    _target_: case_studies.redshift.artifacts.redshift_dc2.RedshiftDC2DataModule
     cached_data_path: ${paths.processed_data_dir_bliss}
     dc2_cat_path: /data/scratch/dc2local/merged_catalog_with_flux_over_50.pkl # we should have a script that makes this on our own
```
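The `_target_` fix matters because Hydra resolves the dotted path by importing the module and looking up the attribute: `case_studies` is a top-level package, not a subpackage of `bliss`, so the old path failed to import. A simplified stand-in for that lookup (not Hydra's actual implementation):

```python
import importlib

def resolve_target(target: str):
    # Split a `_target_` dotted path into module and attribute,
    # import the module, then fetch the attribute from it.
    module_name, _, attr = target.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)
```

With the corrected path, `resolve_target("case_studies.redshift.artifacts.redshift_dc2.RedshiftDC2DataModule")` would import the module that actually exists on `PYTHONPATH`.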

case_studies/redshift/redshift_from_img/continuous.yaml

Lines changed: 5 additions & 5 deletions
```diff
@@ -13,8 +13,8 @@ global_setting:
 variational_factors:
   - _target_: bliss.encoder.variational_dist.NormalFactor
     name: redshifts
-    sample_rearrange: "b ht wt -> b ht wt 1 1"
-    nll_rearrange: "b ht wt 1 1 -> b ht wt"
+    sample_rearrange: "b ht wt 1 -> b ht wt 1 1"
+    nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
     nll_gating: is_galaxy
 
 encoder:
@@ -57,7 +57,7 @@ encoder:
 
 # Can optimize to these metrics by choosing bin carefully
 discrete_metrics:
-  redshift_mearn_square_error_bin:
+  redshift_mean_square_error_bin:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
     bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
     bin_type: "njymag"
@@ -80,9 +80,9 @@ discrete_metrics:
 
 # Standard metric computation
 mode_sample_metrics:
-  redshift_mearn_square_error:
+  redshift_mean_square_error:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredError
-  redshift_mearn_square_error_bin:
+  redshift_mean_square_error_bin:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
     bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
     bin_type: "njymag"
```
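The new `sample_rearrange`/`nll_rearrange` patterns reflect that samples from the factor now arrive with a trailing singleton dimension, matching the commit message's note about getting the correct shapes from sampling. In plain torch terms (shapes are illustrative):

```python
import torch

# Before: samples of shape (b, ht, wt) were expanded to (b, ht, wt, 1, 1).
# After: samples arrive as (b, ht, wt, 1) and gain one extra singleton dim,
# i.e. "b ht wt 1 -> b ht wt 1 1".
samples = torch.randn(2, 8, 8, 1)  # hypothetical batch of redshift samples
expanded = samples.unsqueeze(-1)   # equivalent of the new sample_rearrange
```

The inverse `nll_rearrange` correspondingly keeps the trailing dim: `"b ht wt 1 1 -> b ht wt 1"`.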

case_studies/redshift/redshift_from_img/discrete.yaml

Lines changed: 4 additions & 4 deletions
```diff
@@ -12,7 +12,7 @@ paths:
   root: /home/declan/bliss
 
 variational_factors:
-  - _target_: bliss.encoder.variational_dist.DiscretizedFactor1D
+  - _target_: case_studies.redshift.redshift_from_img.encoder.variational_dist.DiscretizedFactor1D
     name: redshifts
     sample_rearrange: "b ht wt -> b ht wt 1 1"
     nll_rearrange: "b ht wt 1 1 -> b ht wt"
@@ -65,7 +65,7 @@ encoder:
 
 # Can optimize to these metrics by choosing bin carefully
 discrete_metrics:
-  redshift_mearn_square_error_bin:
+  redshift_mean_square_error_bin:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
     bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
     bin_type: "njymag"
@@ -88,9 +88,9 @@ discrete_metrics:
 
 # Standard metric computation
 mode_sample_metrics:
-  redshift_mearn_square_error:
+  redshift_mean_square_error:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredError
-  redshift_mearn_square_error_bin:
+  redshift_mean_square_error_bin:
     _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
     bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
     bin_type: "njymag"
```
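The discretized factor now comes from the case study's own `variational_dist` module rather than core BLISS. As a hedged sketch of what a discretized 1D factor might look like (the bin range, equal-width edges, and mode-as-midpoint convention are illustrative assumptions, not the actual BLISS implementation): a categorical distribution over redshift bins.

```python
import torch

class DiscretizedFactor1DSketch:
    """Toy categorical over equal-width redshift bins (illustrative only)."""

    def __init__(self, logits, z_max=3.0):
        n_bins = logits.shape[-1]
        # Equal-width bin edges over [0, z_max]; an assumption for this sketch.
        self.edges = torch.linspace(0.0, z_max, n_bins + 1)
        self.probs = torch.softmax(logits, dim=-1)

    def mode(self):
        # Report the midpoint of the most probable bin as the point estimate.
        idx = self.probs.argmax(dim=-1)
        return 0.5 * (self.edges[idx] + self.edges[idx + 1])
```

Binning choices like this are exactly why the `bin_cutoffs`/`bin_type` metrics above exist: performance can be optimized by choosing bins carefully.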
