Skip to content

Commit e857c6f

Browse files
shreyasc30timwhite0Shreyas Chandrashekaran
authored
dc2 weak lensing (#1048)
* added changes for plotting extraction to case_studies, plotcollection class, updated paths and configs * added preliminary dc2 data exploration (week 1 and 2) * attempted refactor of plots to use metrics instead of plotcollection, modified yaml files to support change, added null check before plotting in encoder * updated base_config to reflect correct path to plots * Refactor lensing plots * updated encoder and plots to reflect req changes * updated name of sample_image_renders to be more descriptive re function * renamed plots to sample_image_renders and updated yaml to reflect * updated encoder and sample_images to add support for weak_lensing plots (minor changes), refactored weak_lensing plots (major change) * update plot_detections path in region_encoder.py * Delete case_studies/dc2/dc2_shear_conv.ipynb (refactor and move to weak_lensing later on) * Fix list/dict error in lensing config * updated lensing config to match modifications * initial commit of unfinished lensing map EDA nb * updated weak lensing EDA notebook * updated sample_image_renders to match changes in PR#1007 * Move DC2 lensing maps to notebooks dir * Notebook with two-point corr for DC2 * Updated two-point notebook * added dc2 datagen notebook for weak lensing * Use DC2 redshift information in two-point notebook * partial changes and restored full dc2, up to date with master; testing encoder and lensing_dc2 * dc2 for weak lensing naive architecture with subclassed dc2, convnet, encoder, modified metrics, have to make some changes to pass commit tests * updated convnet and encoder to match smaller image size * attempted fix of nans in backward pass * fixed lensing dc2 to remove split files, made corresponding changes in cached dataset, updated lensing convnet to restore to old features net and fixed lensing plots * save plots changes * updated to 1 downsampling * added scale matching for plots * updated lensing_config * updated lensing config * images without background * Remove unnecessary .view in to_full_catalog * Update base_config.yaml * Update base_config.yaml * pylint fix for lensing metrics * Fix shear/convergence shapes in lensing MSE * Add image-level means as baseline estimator for shear and convergence * fixed total in lensing metrics to pass check * updated encoder and vardist to roll back tilecatalog changes, replace with basetilecatalog * removed duplicate methods in lensing encoder * restored base_config from master * updated lensing_dc2 to factor in changes to dc2 * added simple null normalizer to image normalizer (will remove in subsequent commit after fixing asinh), updated lensing config paths, and modified masking operation in lensing_dc2 * updated files to reflect recent changes and PR comments * removed try/catch * updated lensing_dc2 to match recent changes to dc2 --------- Co-authored-by: Tim White <[email protected]> Co-authored-by: Shreyas Chandrashekaran <[email protected]>
1 parent a77531b commit e857c6f

File tree

11 files changed

+2283
-548
lines changed

11 files changed

+2283
-548
lines changed

bliss/encoder/image_normalizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,11 @@ def get_input_tensor(self, batch):
8989
# asinh seems to saturate beyond 5 or so
9090
scaled_images = centered_images * (5.0 / quantiles5d.abs().clamp(1e-6))
9191
return torch.asinh(scaled_images)
92+
93+
94+
class NullNormalizer(torch.nn.Module):
95+
def num_channels_per_band(self):
96+
return 1
97+
98+
def get_input_tensor(self, batch):
99+
return rearrange((batch["images"] + 0.5).clamp(1e-6) * 100, "b bands h w -> b bands 1 h w")

bliss/encoder/variational_dist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
TransformedDistribution,
1515
)
1616

17-
from bliss.catalog import TileCatalog
17+
from bliss.catalog import BaseTileCatalog, TileCatalog
1818

1919

2020
class VariationalDist(torch.nn.Module):
@@ -32,10 +32,10 @@ def _factor_param_pairs(self, x_cat):
3232
dist_params_lst = torch.split(x_cat, split_sizes, 3)
3333
return zip(self.factors, dist_params_lst)
3434

35-
def sample(self, x_cat, use_mode=False):
35+
def sample(self, x_cat, use_mode=False, return_base_cat=False):
3636
fp_pairs = self._factor_param_pairs(x_cat)
3737
d = {qk.name: qk.sample(params, use_mode) for qk, params in fp_pairs}
38-
return TileCatalog(d)
38+
return BaseTileCatalog(d) if return_base_cat else TileCatalog(d)
3939

4040
def compute_nll(self, x_cat, true_tile_cat):
4141
fp_pairs = self._factor_param_pairs(x_cat)

case_studies/weak_lensing/lensing_config.yaml

Lines changed: 72 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,95 @@ defaults:
44
- _self_
55
- override hydra/job_logging: stdout
66

7+
mode: train
8+
9+
paths:
10+
dc2: /data/scratch/dc2local # change for gl
11+
output: /data/scratch/shreyasc/bliss_output # change for gl
12+
713
variational_factors:
814
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
915
name: shear
10-
sample_rearrange: "1 ht wt d -> ht wt 1 d"
11-
nll_rearrange: "b ht wt 1 1 -> b ht wt 1"
1216
nll_gating: null
1317
- _target_: bliss.encoder.variational_dist.NormalFactor
1418
name: convergence
15-
sample_rearrange: "b ht wt -> b ht wt 1 1"
16-
nll_rearrange: "b ht wt 1 1 -> b ht wt"
1719
nll_gating: null
20+
high_clamp: 0.
21+
low_clamp: 0.
1822

19-
prior:
20-
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
21-
n_tiles_h: 48
22-
n_tiles_w: 48
23-
tile_slen: 4
24-
batch_size: 32
25-
prob_galaxy: 1.0
26-
mean_sources: 0.2
27-
arcsec_per_pixel: 0.396
28-
sample_method: cosmology
29-
shear_mean: 0
30-
shear_std: 0.02
31-
convergence_mean: 0
32-
convergence_std: 0.02
33-
num_knots: 2
23+
my_normalizers:
24+
# asinh:
25+
# _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
26+
# q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
27+
# stride: 4
28+
nully:
29+
_target_: bliss.encoder.image_normalizer.NullNormalizer
3430

35-
simulator:
36-
_target_: case_studies.weak_lensing.lensing_simulated_dataset.LensingSimulatedDataset
31+
my_metrics:
32+
lensing_map:
33+
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
3734

38-
generate:
39-
n_image_files: 8
40-
n_batches_per_file: 16
41-
cached_data_path: /data/scratch/shreyas/weak_lensing_cosmology_prior
35+
my_render:
36+
lensing_shear_conv:
37+
_target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
38+
frequency: 1
39+
restrict_batch: 0
40+
tile_slen: 256
41+
save_local: "convergence_only_maps"
4242

43-
cached_simulator:
44-
cached_data_path: /data/scratch/shreyas/weak_lensing_cosmology_prior
45-
batch_size: 24
4643

4744
encoder:
48-
metrics:
45+
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
46+
survey_bands: ["u", "g", "r", "i", "z", "y"]
47+
reference_band: 2 # r-band
48+
tile_slen: 256
49+
optimizer_params:
50+
lr: 1e-2
51+
scheduler_params:
52+
milestones: [32]
53+
gamma: 0.1
54+
image_normalizers: ${my_normalizers}
55+
56+
var_dist:
57+
_target_: bliss.encoder.variational_dist.VariationalDist
58+
tile_slen: ${encoder.tile_slen}
59+
factors: ${variational_factors}
60+
mode_metrics:
61+
_target_: torchmetrics.MetricCollection
62+
_convert_: partial
63+
metrics: ${my_metrics}
64+
sample_metrics:
4965
_target_: torchmetrics.MetricCollection
50-
metrics:
51-
lensing_map:
52-
_target_: case_studies.weak_lensing.lensing_metrics.LensingMapMSE
66+
_convert_: partial
67+
metrics: ${my_metrics}
5368
sample_image_renders:
5469
_target_: torchmetrics.MetricCollection
55-
metrics:
56-
- _target_: bliss.encoder.sample_image_renders.PlotSampleImages
57-
frequency: 1
58-
restrict_batch: 0
59-
tiles_to_crop: 0
60-
tile_slen: ${simulator.decoder.tile_slen}
61-
- _target_: case_studies.weak_lensing.lensing_plots.PlotWeakLensingShearConvergence
62-
frequency: 1
63-
restrict_batch: 0
64-
tiles_to_crop: 0
65-
tile_slen: ${simulator.decoder.tile_slen}
70+
_convert_: partial
71+
metrics: ${my_render}
6672
use_double_detect: false
6773
use_checkerboard: false
74+
train_loss_location: "train_loss_plt"
6875

6976
surveys:
70-
sdss:
71-
_target_: bliss.surveys.sdss.SloanDigitalSkySurvey
72-
dir_path: ${paths.sdss}
73-
fields: # TODO: better arbitary name for fields/bricks?
74-
- run: 2334
75-
camcol: 6
76-
fields: [13]
77-
psf_config:
78-
pixel_scale: 0.396
79-
psf_slen: 25
80-
align_to_band: null
81-
load_image_data: true
77+
dc2:
78+
_target_: case_studies.weak_lensing.lensing_dc2.LensingDC2DataModule
79+
dc2_image_dir: ${paths.dc2}/run2.2i-dr6-v4/coadd-t3828-t3829/deepCoadd-results/
80+
dc2_cat_path: ${paths.dc2}/lensing_catalog.pkl
81+
image_slen: 2048
82+
tile_slen: 256
83+
splits: 0:80/80:90/90:100
84+
batch_size: 1
85+
num_workers: 1
86+
cached_data_path: ${paths.output}/dc2_2048_galid_full_scaled_up
87+
88+
train:
89+
trainer:
90+
logger:
91+
name: dc2_weak_lensing_exp
92+
version: exp_08_05
93+
devices: [6] # cuda:0 for gl
94+
use_distributed_sampler: false
95+
precision: 32-true
96+
data_source: ${surveys.dc2}
97+
pretrained_weights: null
98+
seed: 123123
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from torch import nn
2+
3+
from bliss.encoder.convnet_layers import C3, ConvBlock, Detect
4+
5+
6+
class WeakLensingFeaturesNet(nn.Module):
7+
def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
8+
super().__init__()
9+
10+
nch_hidden = 64
11+
self.preprocess3d = nn.Sequential(
12+
nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
13+
nn.GroupNorm(
14+
num_groups=32, num_channels=nch_hidden
15+
), # sqrt of num channels, get rid of it, even shallower
16+
nn.SiLU(),
17+
)
18+
19+
# TODO: adaptive downsample
20+
self.n_downsample = 1
21+
22+
module_list = []
23+
24+
for _ in range(self.n_downsample):
25+
module_list.append(
26+
ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2, padding=2)
27+
)
28+
nch_hidden *= 2
29+
30+
module_list.extend(
31+
[
32+
ConvBlock(nch_hidden, 64, kernel_size=5, padding=2),
33+
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5, padding=2) for _ in range(1)]),
34+
ConvBlock(64, 128, stride=2),
35+
nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
36+
ConvBlock(128, num_features, stride=1),
37+
]
38+
) # 4
39+
40+
self.net = nn.ModuleList(module_list)
41+
42+
def forward(self, x):
43+
x = self.preprocess3d(x).squeeze(2)
44+
for _i, m in enumerate(self.net):
45+
x = m(x)
46+
47+
return x
48+
49+
50+
class WeakLensingCatalogNet(nn.Module):
51+
def __init__(self, in_channels, out_channels):
52+
super().__init__()
53+
54+
net_layers = [
55+
C3(in_channels, 256, n=1, shortcut=True), # 0
56+
ConvBlock(256, 512, stride=2),
57+
C3(512, 256, n=1, shortcut=True), # true shortcut for skip connection
58+
ConvBlock(
59+
in_channels=256, out_channels=256, kernel_size=3, stride=8, padding=1
60+
), # (1, 256, 128, 128)
61+
ConvBlock(
62+
in_channels=256, out_channels=256, kernel_size=3, stride=4, padding=1
63+
), # (1, 256, 8, 8)
64+
Detect(256, out_channels),
65+
]
66+
self.net = nn.ModuleList(net_layers)
67+
68+
def forward(self, x):
69+
for _i, m in enumerate(self.net):
70+
x = m(x)
71+
return x

0 commit comments

Comments
 (0)