Skip to content

Commit

Permalink
Qh/redshift evaluation (#1051)
Browse files (browse the repository at this point in the history)
* declan

* dc2 bliss pred redshift can run

rebase master to branch

* can predict redshift

rebase

* modify

rebase

* redshifts modify

* add redshift training and func of testing using best encoder

* add redshift training and func of testing using best encoder

rebase

* tidy code

* best encoder for test

* can run 0-dist code

* remove no need file

* rm

* rm var

* rm est

* format

* make same as master

* change config

* format

* pred redshift only

* bliss pred redshift

* dc2 mag formula change

* add nanojy mag

* can run lsst prediction and evaluation

* update code for all metrics evaluation on LSST and Bliss

* redshift plots stratified by true redshift and blendedness

* rm unnecessary files and code modification for PR

* rm unnecessary utils

---------

Co-authored-by: Qiaozhi Huang <[email protected]>
  • Loading branch information
georgeyfly and Qiaozhi Huang authored Aug 2, 2024
1 parent 6898c24 commit c0645a3
Show file tree
Hide file tree
Showing 9 changed files with 26,592 additions and 499 deletions.
1 change: 0 additions & 1 deletion bliss/encoder/variational_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
class VariationalDist(torch.nn.Module):
def __init__(self, factors, tile_slen):
super().__init__()

self.factors = factors
self.tile_slen = tile_slen

Expand Down
26,208 changes: 25,860 additions & 348 deletions case_studies/redshift/evaluation/dc2_plot.ipynb

Large diffs are not rendered by default.

150 changes: 119 additions & 31 deletions case_studies/redshift/evaluation/notebook_plot.yaml
Original file line number Diff line number Diff line change
@@ -1,35 +1,57 @@
---
defaults:
- ../redshift_from_img@_here_: full_train_config_redshift
- ../../../bliss/conf@_here_: base_config
- _self_

paths:
root: /home/qiaozhih/bliss

global_setting:
min_flux_for_loss: 50 # you need to regenerate split_results after changing this number

variational_factors:
- _target_: bliss.encoder.variational_dist.NormalFactor
name: redshifts
sample_rearrange: "b ht wt -> b ht wt 1 1"
nll_rearrange: "b ht wt 1 1 -> b ht wt"
nll_gating: is_galaxy

encoder:
_target_: case_studies.redshift.redshift_from_img.encoder.encoder.RedshiftsEncoder
# _target_: bliss.encoder.encoder.Encoder
survey_bands: ["g", "i", "r", "u", "y", "z"]
tile_slen: 4
tiles_to_crop: 1
min_flux_for_loss: ${global_setting.min_flux_for_loss}
min_flux_for_metrics: 100
optimizer_params:
lr: 1e-3
scheduler_params:
milestones: [32]
gamma: 0.1
image_normalizer:
_target_: bliss.encoder.image_normalizer.ImageNormalizer
bands: [0, 1, 2, 3, 4, 5]
include_original: false
include_background: false
concat_psf_params: false
num_psf_params: 4 # 6 for SDSS, 4 for DC2
log_transform_stdevs: []
use_clahe: true
clahe_min_stdev: 200
# matcher:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftsCatalogMatcher
# match_gating: is_galaxy
image_normalizers:
psf:
_target_: bliss.encoder.image_normalizer.PsfAsImage
num_psf_params: 4 # 6 for SDSS, 4 for DC2, 10 for DES
clahe:
_target_: bliss.encoder.image_normalizer.ClaheNormalizer
min_stdev: 200
asinh:
_target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
# matcher:
# _target_: bliss.encoder.metrics.CatalogMatcher
# dist_slack: 1.0
# mag_slack: null
# mag_band: 2 # SDSS r-band
matcher:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftsCatalogMatcher
match_gating: is_galaxy
metrics:
mode_metrics:
_target_: torchmetrics.MetricCollection
_convert_: "partial"
metrics: ${my_metrics_test}
sample_metrics:
_target_: torchmetrics.MetricCollection
_convert_: "partial"
metrics: ${my_metrics_test}
Expand All @@ -40,22 +62,88 @@ encoder:
use_checkerboard: false

my_metrics_test:
redshift_mearn_square_error_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
mag_bin_cutoffs: [200, 400, 600, 800, 1000]
redshift_outlier_fraction_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFractionBin
mag_bin_cutoffs: [200, 400, 600, 800, 1000]
redshift_nmad_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftNormalizedMedianAbsDevBin
mag_bin_cutoffs: [200, 400, 600, 800, 1000]
redshift_outlier_fraction_catastrophic_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFractionCataBin
mag_bin_cutoffs: [200, 400, 600, 800, 1000]
redshift_bias_bin:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftBiasBin
mag_bin_cutoffs: [200, 400, 600, 800, 1000]
# redshift_mearn_square_error:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredError
# redshift_mearn_square_error_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"
# redshift_mean_square_error_blendedness:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorBlendedness
# bin_cutoffs: [0.0001, 0.02, 0.1, 0.2, 0.6]
redshift_mean_square_error_true_redshifts:
_target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftMeanSquaredErrorTrueRedshift
bin_cutoffs: [0.5, 1, 1.5, 2, 2.5, 3]
# redshift_outlier_fraction:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFraction
# redshift_outlier_fraction_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFractionBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"
# redshift_nmad:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftNormalizedMedianAbsDev
# redshift_nmad_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftNormalizedMedianAbsDevBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"
# redshift_outlier_fraction_catastrophic:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFractionCata
# redshift_outlier_fraction_catastrophic_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftOutlierFractionCataBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"
# redshift_bias:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftBias
# redshift_bias_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftBiasBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"
# redshift_abs_bias:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftAbsBias
# redshift_abs_bias_bin:
# _target_: case_studies.redshift.redshift_from_img.encoder.metrics.RedshiftAbsBiasBin
# bin_cutoffs: [23.9, 24.1, 24.5, 24.9, 25.6]
# bin_type: "njymag"

train:
trainer:
logger:
name: DC2_redshift_training
version: DC2_redshift_only_large_split_blend
save_dir: ${paths.root}/case_studies/redshift/redshift_from_img/
precision: 32
strategy:
_target_: pytorch_lightning.strategies.DDPStrategy
find_unused_parameters: true
process_group_backend: nccl
timeout:
_target_: datetime.timedelta
seconds: 180000
val_check_interval: 0.5
# check_val_every_n_epoch: 1
# devices: [0, 2, 3, 4]
devices: [1]
max_epochs: 50
callbacks:
checkpointing:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
filename: encoder_{val/mode/redshifts/mse:.6f}
save_top_k: 5
verbose: True
# monitor: val/_loss
monitor: val/mode/redshifts/mse
mode: min
early_stopping:
_target_: pytorch_lightning.callbacks.early_stopping.EarlyStopping
monitor: val/_loss
mode: min
patience: 500
data_source: ${surveys.dc2}
pretrained_weights: null

surveys:
dc2:
batch_size: 64
dc2_cat_path: /data/scratch/dc2local/merged_catalog_with_flux_over_50.pkl
cached_data_path: /data/scratch/dc2local/dc2_cached_data
batch_size: 4
max_sources_per_tile: 5
Loading

0 comments on commit c0645a3

Please sign in to comment.