Skip to content

Commit 6984255

Browse files
shreyasc30timwhite0Shreyas ChandrashekaranShreyas Chandrashekaran
authored
weak lensing encoder shear updates (#1072)
* Refactor generate_cached_data in lensing_dc2
* Decrease learning rate, remove clamp on convergence stdev
* Remove some print statements in lensing_encoder
* In-progress changes to normalizer, convnet, and encoder, as well as metrics and plots, to only estimate shear
* New architecture with ResNet and ResNeXt layers, as well as preliminary changes to support psfasimage with the full PSF from a limited object table (no tests yet)
* Updated to make shear1 and shear2 separate normal factors
* Updated lensing config to split up shear_1 and shear_2 as normal factors (nf)
* Updated network due to OOM
* Removed print statements from encoder
* Rolled back some debug changes and re-established consistency with master
* Deduped lensing config
* Styling tests
* Style checks update
* Removed try/catch from cached_dataset and made fix to lensing_dc2
* Fixed lensing MSE denominator
* Fixed lensing config
---------
Co-authored-by: Tim White <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
1 parent 0907a16 commit 6984255

File tree

9 files changed

+251
-107
lines changed

9 files changed

+251
-107
lines changed

bliss/encoder/variational_dist.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
171171
def get_dist(self, params):
    """Build the two-dimensional Normal distribution encoded by ``params``.

    ``params`` is indexed with four axes. Along the last axis, the first two
    entries are used as the mean; the remaining entries are clamped to
    ``[self.low_clamp, self.high_clamp]``, exponentiated and square-rooted to
    obtain the standard deviation (i.e. they act as log-variances).

    Returns:
        An ``Independent`` wrapper around ``Normal`` that reinterprets the
        last batch axis as the event axis.
    """
    loc = params[:, :, :, :2]
    log_var = params[:, :, :, 2:]
    # Clamp before exponentiating so extreme network outputs stay finite.
    bounded_log_var = log_var.clamp(self.low_clamp, self.high_clamp)
    scale = bounded_log_var.exp().sqrt()
    return Independent(Normal(loc, scale), 1)
176175

177176

bliss/surveys/dc2.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,16 +387,21 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
387387
return cls(height, width, d), psf_params, match_id
388388

389389
@classmethod
def get_bands_flux_and_psf(cls, bands, catalog, median=True):
    """Collect per-band fluxes and PSF parameters from ``catalog``.

    For every band ``b`` this reads the column ``"flux_" + b`` and the four
    PSF columns ``IxxPSF_pixel_``, ``IyyPSF_pixel_``, ``IxyPSF_pixel_`` and
    ``psf_fwhm_`` (each suffixed with ``b``).

    Args:
        bands: iterable of band-name suffixes used to build column names.
        catalog: table indexable by column name whose columns expose
            ``.values`` as NumPy arrays (presumably a pandas DataFrame --
            confirm against callers).
        median: if True, each PSF column is reduced to its NaN-ignoring
            median; if False, the per-object values are kept.

    Returns:
        A ``(flux, psf_params)`` tuple. ``flux`` stacks one row per band and
        is then transposed. ``psf_params`` has a leading axis of size 1
        followed by ``bands x 4`` when ``median`` is True, or
        ``bands x 4 x n_obj`` otherwise; its dtype is float32.
    """
    # The PSF column prefixes do not depend on the band, so define them once.
    psf_params_name = ("IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_")

    flux_list = []
    psf_params_list = []
    for b in bands:
        flux_list.append(torch.from_numpy(catalog["flux_" + b].values))

        psf_params_cur_band = []
        for prefix in psf_params_name:
            values = catalog[prefix + b].values
            if median:
                psf_params_cur_band.append(np.nanmedian(values).astype(np.float32))
            else:
                psf_params_cur_band.append(values.astype(np.float32))

        # Stack in NumPy first: torch.tensor() on a list of arrays is the slow,
        # element-by-element path and emits a warning for per-object data.
        psf_params_list.append(
            torch.from_numpy(np.stack(psf_params_cur_band))
        )  # 4 (params per band) [x n_obj]

    return torch.stack(flux_list).t(), torch.stack(psf_params_list).unsqueeze(0)

case_studies/weak_lensing/lensing_config.yaml

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ defaults:
77
mode: train
88

99
paths:
10+
1011
dc2: /data/scratch/dc2local
11-
cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_shear02
12-
output: /data/scratch/twhit/bliss_output
12+
output: /data/scratch/shreyasc/bliss_output
1313

1414
prior:
1515
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
@@ -34,20 +34,26 @@ cached_simulator:
3434
train_transforms: []
3535

3636
variational_factors:
37-
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
38-
name: shear
37+
- _target_: bliss.encoder.variational_dist.NormalFactor
38+
name: shear_1
3939
nll_gating: null
4040
- _target_: bliss.encoder.variational_dist.NormalFactor
41-
name: convergence
41+
name: shear_2
4242
nll_gating: null
43-
high_clamp: 20.0
44-
low_clamp: -20.0
43+
# - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
44+
# name: shear
45+
# nll_gating: null
46+
# - _target_: bliss.encoder.variational_dist.NormalFactor
47+
# name: convergence
48+
# nll_gating: null
49+
# high_clamp: 20.0
50+
# low_clamp: -20.0
4551

4652
my_normalizers:
4753
# asinh:
4854
# _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
4955
# q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
50-
# stride: 4
56+
# sample_every_n: 4
5157
nully:
5258
_target_: bliss.encoder.image_normalizer.NullNormalizer
5359

@@ -61,13 +67,15 @@ my_render:
6167
frequency: 1
6268
restrict_batch: 0
6369
tile_slen: 256
64-
save_local: "lensing_maps"
70+
save_local: lensing_maps
6571

6672
encoder:
6773
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
6874
survey_bands: ["u", "g", "r", "i", "z", "y"]
6975
reference_band: 2 # r-band
7076
tile_slen: 256
77+
n_tiles: 8
78+
nch_hidden: 64
7179
optimizer_params:
7280
lr: 1e-3
7381
scheduler_params:
@@ -93,7 +101,7 @@ encoder:
93101
metrics: ${my_render}
94102
use_double_detect: false
95103
use_checkerboard: false
96-
train_loss_location: "train_loss_plt"
104+
train_loss_location: train_loss
97105

98106
surveys:
99107
dc2:
@@ -108,8 +116,19 @@ surveys:
108116
avg_ellip_kernel_sigma: 3
109117
batch_size: 1
110118
num_workers: 1
111-
cached_data_path: ${paths.dc2}/dc2_lensing_splits_img2048_tile256
119+
cached_data_path: ${paths.output}/dc2_corrected_shear_only_cd_fix
112120

113-
generate:
114-
n_image_files: 50
115-
n_batches_per_file: 4
121+
# generate:
122+
# n_image_files: 50
123+
# n_batches_per_file: 4
124+
train:
125+
trainer:
126+
logger:
127+
name: dc2_weak_lensing_exp
128+
version: exp_09_16
129+
devices: [0] # cuda:0 for gl
130+
use_distributed_sampler: false
131+
precision: 32-true
132+
data_source: ${surveys.dc2}
133+
pretrained_weights: null
134+
seed: 123123
Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import math
2+
13
from torch import nn
24

3-
from bliss.encoder.convnet_layers import C3, ConvBlock, Detect
5+
from bliss.encoder.convnet_layers import Detect
6+
from case_studies.weak_lensing.lensing_convnet_layers import RN2Block
47

58

69
class WeakLensingFeaturesNet(nn.Module):
7-
def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
10+
def __init__(self, n_bands, ch_per_band, num_features, tile_slen, nch_hidden):
811
super().__init__()
912

10-
nch_hidden = 64
1113
self.preprocess3d = nn.Sequential(
1214
nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
1315
nn.GroupNorm(
@@ -16,52 +18,50 @@ def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
1618
nn.SiLU(),
1719
)
1820

19-
# TODO: adaptive downsample
20-
self.n_downsample = 1
21-
22-
module_list = []
23-
24-
for _ in range(self.n_downsample):
25-
module_list.append(ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2))
26-
nch_hidden *= 2
21+
n_blocks2 = int(math.log2(num_features)) - int(math.log2(nch_hidden))
22+
module_list = [RN2Block(nch_hidden, nch_hidden), RN2Block(nch_hidden, nch_hidden)]
23+
for i in range(n_blocks2):
24+
in_dim = nch_hidden * (2**i)
25+
out_dim = in_dim * 2
2726

28-
module_list.extend(
29-
[
30-
ConvBlock(nch_hidden, 64, kernel_size=5),
31-
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5) for _ in range(1)]),
32-
ConvBlock(64, 128, stride=2),
33-
nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
34-
ConvBlock(128, num_features, stride=1),
35-
]
36-
) # 4
27+
module_list.append(RN2Block(in_dim, out_dim, stride=2))
28+
module_list.append(RN2Block(out_dim, out_dim))
3729

3830
self.net = nn.ModuleList(module_list)
3931

4032
def forward(self, x):
4133
x = self.preprocess3d(x).squeeze(2)
42-
for _i, m in enumerate(self.net):
43-
x = m(x)
44-
34+
for _idx, layer in enumerate(self.net):
35+
x = layer(x)
4536
return x
4637

4738

48-
class WeakLensingCatalogNet(nn.Module):  # TODO: get the dimensions down to n_tiles
    """Head that narrows feature channels with RN2Blocks and ends in a Detect layer.

    The channel count is halved once per block, starting from ``in_channels``,
    for ``floor(log2(in_channels)) - ceil(log2(out_channels))`` blocks. The
    first ``(n_blocks + 4) // 2`` of those blocks also use ``stride=2``; the
    rest keep the spatial size. A final ``Detect`` layer maps the remaining
    channels to ``out_channels``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the final Detect layer.
        n_tiles: accepted for interface compatibility; not used by this
            implementation yet (see the TODO on the class line).
    """

    def __init__(self, in_channels, out_channels, n_tiles):
        super().__init__()

        n_blocks2 = int(math.log2(in_channels)) - int(math.ceil(math.log2(out_channels)))
        n_strided = (n_blocks2 + 4) // 2

        net_layers = []
        # Start from in_channels so that Detect still receives a valid width
        # when no halving block is needed (n_blocks2 <= 0).
        last_out_dim = in_channels
        for i in range(n_blocks2):
            in_dim = in_channels // (2**i)
            out_dim = in_dim // 2
            if i < n_strided:
                net_layers.append(RN2Block(in_dim, out_dim, stride=2))
            else:
                net_layers.append(RN2Block(in_dim, out_dim))
            last_out_dim = out_dim

        # Final detection layer to reduce channels
        self.detect = Detect(last_out_dim, out_channels)
        self.net = nn.ModuleList(net_layers)

    def forward(self, x):
        """Apply every RN2Block in order, then the final Detect layer."""
        for layer in self.net:
            x = layer(x)
        return self.detect(x)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import math
2+
3+
from torch import nn
4+
5+
6+
class RN2Block(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by GroupNorm.

    The block input is added to the output of the second normalization before
    the final SiLU. When ``stride != 1`` or the channel count changes, the
    skip path is a 1x1 convolution (with the same stride) plus GroupNorm so
    that both branches have matching shapes.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution and of the skip projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        n_groups = self._norm_groups(out_channels)

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.gn1 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
        self.silu = nn.SiLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.gn2 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(num_groups=n_groups, num_channels=out_channels),
            )

    @staticmethod
    def _norm_groups(num_channels):
        """Return a GroupNorm group count that divides ``num_channels``.

        Perfect squares use their square root; other counts start from
        ``floor(sqrt(2 * num_channels))``. For powers of two this already
        divides the channel count. For any other width the value is lowered
        to the nearest divisor, because GroupNorm rejects a group count that
        does not divide the number of channels.
        """
        root = math.sqrt(num_channels)
        if root.is_integer():
            n_groups = int(root)
        else:
            n_groups = int(math.sqrt(num_channels * 2))
        while n_groups > 1 and num_channels % n_groups != 0:
            n_groups -= 1
        return n_groups

    def forward(self, x):
        """Run the two conv/norm stages and add the (possibly projected) input."""
        identity = x

        out = self.conv1(x)
        out = self.gn1(out)
        out = self.silu(out)

        out = self.conv2(out)
        out = self.gn2(out)

        # Compare against None explicitly instead of relying on the
        # truthiness (length) of the nn.Sequential container.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.silu(out)
47+
48+
49+
class ResNeXtBlock(nn.Module):
    """Bottleneck residual block with a grouped 3x3 convolution.

    The main path is 1x1 conv -> GroupNorm -> SiLU -> grouped 3x3 conv
    (carrying the stride) -> GroupNorm -> SiLU -> 1x1 conv -> GroupNorm. The
    block input is then added and a final SiLU is applied. When
    ``stride != 1`` or the channel count changes, the skip path is a 1x1
    convolution (with the same stride) plus GroupNorm.

    Args:
        in_channels: number of channels of the input feature map.
        mid_channels: width of the bottleneck; passed to the grouped
            convolution together with ``groups``.
        out_channels: number of channels produced by the block.
        stride: stride of the grouped convolution and of the skip projection.
        groups: ``groups`` argument of the 3x3 convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=1, groups=32):
        super().__init__()
        mid_norm_n_groups = self._norm_groups(mid_channels)
        out_norm_n_groups = self._norm_groups(out_channels)

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
        self.gn1 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=groups
        )
        self.gn2 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.gn3 = nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels)
        self.silu = nn.SiLU(inplace=True)

        # Adjust the shortcut connection to match the output dimensions
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
                # GroupNorm needs both the group count and the channel count;
                # passing only one positional argument raises TypeError.
                nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels),
            )

    @staticmethod
    def _norm_groups(num_channels):
        """Return a GroupNorm group count that divides ``num_channels``.

        Perfect squares use their square root; other counts start from
        ``floor(sqrt(2 * num_channels))``. For powers of two this already
        divides the channel count. For any other width the value is lowered
        to the nearest divisor, because GroupNorm rejects a group count that
        does not divide the number of channels.
        """
        root = math.sqrt(num_channels)
        if root.is_integer():
            n_groups = int(root)
        else:
            n_groups = int(math.sqrt(num_channels * 2))
        while n_groups > 1 and num_channels % n_groups != 0:
            n_groups -= 1
        return n_groups

    def forward(self, x):
        """Run the bottleneck path and add the (possibly projected) input."""
        residual = x
        out = self.conv1(x)
        out = self.gn1(out)
        out = self.silu(out)
        out = self.conv2(out)
        out = self.gn2(out)
        out = self.silu(out)
        out = self.conv3(out)
        out = self.gn3(out)
        # Compare against None explicitly instead of relying on the
        # truthiness (length) of the nn.Sequential container.
        if self.shortcut is not None:
            residual = self.shortcut(x)
        out += residual
        return self.silu(out)

0 commit comments

Comments
 (0)