dc2 weak lensing (#1048)
* added changes for plotting extraction to case_studies, plotcollection class, updated paths and configs

* added preliminary dc2 data exploration (week 1 and 2)

* attempted refactor of plots to use metrics instead of plotcollection, modified yaml files to support change, added null check before plotting in encoder

* updated base_config to reflect correct path to plots

* Refactor lensing plots

* updated encoder and plots to reflect req changes

* updated name of sample_image_renders to be more descriptive re function

* renamed plots to sample_image_renders and updated yaml to reflect

* updated encoder and sample_images to add support for weak_lensing plots (minor changes), refactored weak_lensing plots (major change)

* update plot_detections path in region_encoder.py

* Delete case_studies/dc2/dc2_shear_conv.ipynb (refactor and move to weak_lensing later on)

* Fix list/dict error in lensing config

* updated lensing config to match modifications

* initial commit of unfinished lensing map EDA nb

* updated weak lensing EDA notebook

* updated sample_image_renders to match changes in PR#1007

* Move DC2 lensing maps to notebooks dir

* Notebook with two-point corr for DC2

* Updated two-point notebook

* added dc2 datagen notebook for weak lensing

* Use DC2 redshift information in two-point notebook

* partial changes and restored full dc2, up to date with master; testing encoder and lensing_dc2

* naive dc2 weak lensing architecture with subclassed dc2, convnet, and encoder; modified metrics; still needs some changes to pass commit tests

* updated convnet and encoder to match smaller image size

* attempted fix of nans in backward pass

* fixed lensing dc2 to remove split files, made corresponding changes in cached dataset, updated lensing convnet to restore the old features net, and fixed lensing plots

* save plots changes

* updated to 1 downsampling

* added scale matching for plots

* updated lensing_config

* updated lensing config

* images without background

* Remove unnecessary .view in to_full_catalog

* Update base_config.yaml

* Update base_config.yaml

* pylint fix for lensing metrics

* Fix shear/convergence shapes in lensing MSE

* Add image-level means as baseline estimator for shear and convergence

* fixed total in lensing metrics to pass check

* updated encoder and vardist to roll back tilecatalog changes, replace with basetilecatalog

* removed duplicate methods in lensing encoder

* restored base_config from master

* updated lensing_dc2 to factor in changes to dc2

* added simple null normalizer to image normalizer (will remove in subsequent commit after fixing asinh), updated lensing config paths, and modified masking operation in lensing_dc2

* updated files to reflect recent changes and PR comments

* removed try/catch

* updated lensing_dc2 to match recent changes to dc2

---------

Co-authored-by: Tim White <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
3 people authored Aug 5, 2024
1 parent a77531b commit e857c6f
Showing 11 changed files with 2,283 additions and 548 deletions.
8 changes: 8 additions & 0 deletions bliss/encoder/image_normalizer.py
@@ -89,3 +89,11 @@ def get_input_tensor(self, batch):
         # asinh seems to saturate beyond 5 or so
         scaled_images = centered_images * (5.0 / quantiles5d.abs().clamp(1e-6))
         return torch.asinh(scaled_images)
+
+
+class NullNormalizer(torch.nn.Module):
+    def num_channels_per_band(self):
+        return 1
+
+    def get_input_tensor(self, batch):
+        return rearrange((batch["images"] + 0.5).clamp(1e-6) * 100, "b bands h w -> b bands 1 h w")
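For context, a minimal standalone sketch of what NullNormalizer does to a batch (toy shapes; assumes torch and einops are installed, with rearrange presumably already imported at the top of image_normalizer.py):

import torch
from einops import rearrange

# Toy batch: 2 images, 6 bands, 64x64 pixels.
batch = {"images": torch.randn(2, 6, 64, 64)}

# Shift, clamp away non-positive values, rescale, and insert a singleton
# channels-per-band axis so downstream nets see (b, bands, 1, h, w).
out = rearrange((batch["images"] + 0.5).clamp(1e-6) * 100, "b bands h w -> b bands 1 h w")
print(out.shape)  # torch.Size([2, 6, 1, 64, 64])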
6 changes: 3 additions & 3 deletions bliss/encoder/variational_dist.py
@@ -14,7 +14,7 @@
     TransformedDistribution,
 )
 
-from bliss.catalog import TileCatalog
+from bliss.catalog import BaseTileCatalog, TileCatalog
 
 
 class VariationalDist(torch.nn.Module):
@@ -32,10 +32,10 @@ def _factor_param_pairs(self, x_cat):
         dist_params_lst = torch.split(x_cat, split_sizes, 3)
         return zip(self.factors, dist_params_lst)
 
-    def sample(self, x_cat, use_mode=False):
+    def sample(self, x_cat, use_mode=False, return_base_cat=False):
         fp_pairs = self._factor_param_pairs(x_cat)
         d = {qk.name: qk.sample(params, use_mode) for qk, params in fp_pairs}
-        return TileCatalog(d)
+        return BaseTileCatalog(d) if return_base_cat else TileCatalog(d)
 
     def compute_nll(self, x_cat, true_tile_cat):
         fp_pairs = self._factor_param_pairs(x_cat)
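A hedged usage sketch of the new return_base_cat flag (the factor kwargs mirror lensing_config.yaml below; the assumption that a Normal factor consumes two parameter channels is illustrative, not taken from this diff):

import torch

from bliss.encoder.variational_dist import NormalFactor, VariationalDist

# One Normal factor over per-tile convergence, as in the lensing config.
factors = [
    NormalFactor(
        name="convergence",
        sample_rearrange="b ht wt -> b ht wt 1 1",
        nll_rearrange="b ht wt 1 1 -> b ht wt",
        nll_gating=None,
    )
]
var_dist = VariationalDist(factors=factors, tile_slen=256)

# x_cat: (batch, tiles_h, tiles_w, n_params); assume (loc, scale) per factor.
x_cat = torch.randn(1, 8, 8, 2)
cat = var_dist.sample(x_cat, use_mode=True, return_base_cat=True)
# With return_base_cat=True the sample is wrapped in BaseTileCatalog, which
# skips TileCatalog's source-count bookkeeping (handy for dense lensing maps).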
127 changes: 72 additions & 55 deletions case_studies/weak_lensing/lensing_config.yaml
@@ -4,78 +4,95 @@ defaults:
  - _self_
  - override hydra/job_logging: stdout

mode: train

paths:
  dc2: /data/scratch/dc2local # change for gl
  output: /data/scratch/shreyasc/bliss_output # change for gl

variational_factors:
  - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
    name: shear
    sample_rearrange: "1 ht wt d -> ht wt 1 d"
    nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
    nll_gating: null
  - _target_: bliss.encoder.variational_dist.NormalFactor
    name: convergence
    sample_rearrange: "b ht wt -> b ht wt 1 1"
    nll_rearrange: "b ht wt 1 1 -> b ht wt"
    nll_gating: null
    high_clamp: 0.
    low_clamp: 0.

prior:
  _target_: case_studies.weak_lensing.lensing_prior.LensingPrior
  n_tiles_h: 48
  n_tiles_w: 48
  tile_slen: 4
  batch_size: 32
  prob_galaxy: 1.0
  mean_sources: 0.2
  arcsec_per_pixel: 0.396
  sample_method: cosmology
  shear_mean: 0
  shear_std: 0.02
  convergence_mean: 0
  convergence_std: 0.02
  num_knots: 2
my_normalizers:
  # asinh:
  #   _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
  #   q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
  #   stride: 4
  nully:
    _target_: bliss.encoder.image_normalizer.NullNormalizer

simulator:
  _target_: case_studies.weak_lensing.lensing_simulated_dataset.LensingSimulatedDataset
my_metrics:
  lensing_map:
    _target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE

generate:
  n_image_files: 8
  n_batches_per_file: 16
  cached_data_path: /data/scratch/shreyas/weak_lensing_cosmology_prior
my_render:
  lensing_shear_conv:
    _target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
    frequency: 1
    restrict_batch: 0
    tile_slen: 256
    save_local: "convergence_only_maps"

cached_simulator:
  cached_data_path: /data/scratch/shreyas/weak_lensing_cosmology_prior
  batch_size: 24

encoder:
  metrics:
  _target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
  survey_bands: ["u", "g", "r", "i", "z", "y"]
  reference_band: 2 # r-band
  tile_slen: 256
  optimizer_params:
    lr: 1e-2
  scheduler_params:
    milestones: [32]
    gamma: 0.1
  image_normalizers: ${my_normalizers}

  var_dist:
    _target_: bliss.encoder.variational_dist.VariationalDist
    tile_slen: ${encoder.tile_slen}
    factors: ${variational_factors}
  mode_metrics:
    _target_: torchmetrics.MetricCollection
    _convert_: partial
    metrics: ${my_metrics}
  sample_metrics:
    _target_: torchmetrics.MetricCollection
    metrics:
      lensing_map:
        _target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
    _convert_: partial
    metrics: ${my_metrics}
  sample_image_renders:
    _target_: torchmetrics.MetricCollection
    metrics:
      - _target_: bliss.encoder.sample_image_renders.PlotSampleImages
        frequency: 1
        restrict_batch: 0
        tiles_to_crop: 0
        tile_slen: ${simulator.decoder.tile_slen}
      - _target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
        frequency: 1
        restrict_batch: 0
        tiles_to_crop: 0
        tile_slen: ${simulator.decoder.tile_slen}
    _convert_: partial
    metrics: ${my_render}
  use_double_detect: false
  use_checkerboard: false
  train_loss_location: "train_loss_plt"

surveys:
  sdss:
    _target_: bliss.surveys.sdss.SloanDigitalSkySurvey
    dir_path: ${paths.sdss}
    fields: # TODO: better arbitrary name for fields/bricks?
      - run: 2334
        camcol: 6
        fields: [13]
    psf_config:
      pixel_scale: 0.396
      psf_slen: 25
    align_to_band: null
    load_image_data: true
  dc2:
    _target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
    dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
    dc2_cat_path: ${paths.dc2}/lensing_catalog.pkl
    image_slen: 2048
    tile_slen: 256
    splits: 0:80/80:90/90:100
    batch_size: 1
    num_workers: 1
    cached_data_path: ${paths.output}/dc2_2048_galid_full_scaled_up

train:
  trainer:
    logger:
      name: dc2_weak_lensing_exp
      version: exp_08_05
    devices: [6] # cuda:0 for gl
    use_distributed_sampler: false
    precision: 32-true
  data_source: ${surveys.dc2}
  pretrained_weights: null
  seed: 123123
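A hedged sketch of composing and instantiating this config with Hydra (assumes a script sitting next to lensing_config.yaml; interpolations such as ${my_normalizers} resolve at compose time):

import hydra

with hydra.initialize(config_path=".", version_base=None):
    cfg = hydra.compose(config_name="lensing_config")

# _target_ fields turn config nodes into live objects.
encoder = hydra.utils.instantiate(cfg.encoder)  # WeakLensingEncoder
datamodule = hydra.utils.instantiate(cfg.surveys.dc2)  # LensingDC2DataModule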
71 changes: 71 additions & 0 deletions case_studies/weak_lensing/lensing_convnet.py
@@ -0,0 +1,71 @@
from torch import nn

from bliss.encoder.convnet_layers import C3, ConvBlock, Detect


class WeakLensingFeaturesNet(nn.Module):
    def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
        super().__init__()

        nch_hidden = 64
        self.preprocess3d = nn.Sequential(
            nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
            nn.GroupNorm(
                num_groups=32, num_channels=nch_hidden
            ),  # sqrt of num channels, get rid of it, even shallower
            nn.SiLU(),
        )

        # TODO: adaptive downsample
        self.n_downsample = 1

        module_list = []

        for _ in range(self.n_downsample):
            module_list.append(
                ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2, padding=2)
            )
            nch_hidden *= 2

        module_list.extend(
            [
                ConvBlock(nch_hidden, 64, kernel_size=5, padding=2),
                nn.Sequential(*[ConvBlock(64, 64, kernel_size=5, padding=2) for _ in range(1)]),
                ConvBlock(64, 128, stride=2),
                nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
                ConvBlock(128, num_features, stride=1),
            ]
        )  # 4

        self.net = nn.ModuleList(module_list)

    def forward(self, x):
        x = self.preprocess3d(x).squeeze(2)
        for _i, m in enumerate(self.net):
            x = m(x)

        return x


class WeakLensingCatalogNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        net_layers = [
            C3(in_channels, 256, n=1, shortcut=True),  # 0
            ConvBlock(256, 512, stride=2),
            C3(512, 256, n=1, shortcut=True),  # true shortcut for skip connection
            ConvBlock(
                in_channels=256, out_channels=256, kernel_size=3, stride=8, padding=1
            ),  # (1, 256, 128, 128)
            ConvBlock(
                in_channels=256, out_channels=256, kernel_size=3, stride=4, padding=1
            ),  # (1, 256, 8, 8)
            Detect(256, out_channels),
        ]
        self.net = nn.ModuleList(net_layers)

    def forward(self, x):
        for _i, m in enumerate(self.net):
            x = m(x)
        return x
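A hedged smoke test of the shape flow through the two nets (image and tile sizes follow lensing_config.yaml; the intermediate spatial sizes assume ConvBlock's default 3x3 kernel with padding 1, and out_channels is illustrative):

import torch

from case_studies.weak_lensing.lensing_convnet import (
    WeakLensingCatalogNet,
    WeakLensingFeaturesNet,
)

n_bands, ch_per_band = 6, 1  # six LSST bands, one channel each under NullNormalizer
images = torch.randn(1, n_bands, ch_per_band, 2048, 2048)

features_net = WeakLensingFeaturesNet(n_bands, ch_per_band, num_features=256, tile_slen=256)
features = features_net(images)  # two stride-2 stages: roughly (1, 256, 512, 512)

catalog_net = WeakLensingCatalogNet(in_channels=256, out_channels=2)
out = catalog_net(features)  # strides 2, 8, 4 reduce 512 -> 8: an 8x8 grid,
# i.e. one cell per 256-pixel tile of the 2048-pixel DC2 image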
