Skip to content

Commit

Permalink
weak lensing encoder shear updates (#1072)
Browse files Browse the repository at this point in the history
* Refactor generate_cached_data in lensing_dc2

* Decrease learning rate, remove clamp on convergence stdev

* Remove some print statements in lensing_encoder

* in progress changes to normalizer, convnet, and encoder, as well as metrics and plots to only estimate shear

* new architecture with resnet and resnetx layers as well as prelim changes to support psfasimage with full PSF from limited object table but no tests yet

* updated to make shear1 and shear2 separate normal factors

* updated lensing config to split up shear_1 and shear_2 as nf

* updated network due to OOM

* removed print statements from enc

* rolled back some debug changes and re-established consistency with master

* deduped lensing config

* styling tests

* style checks update

* removed try/catch from cached_datset and made fix to lensing_dc2

* fixed lensing MSE denominator

* fixed lensing config

---------

Co-authored-by: Tim White <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
Co-authored-by: Shreyas Chandrashekaran <[email protected]>
  • Loading branch information
4 people authored Sep 18, 2024
1 parent 0907a16 commit 6984255
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 107 deletions.
1 change: 0 additions & 1 deletion bliss/encoder/variational_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
def get_dist(self, params):
mean = params[:, :, :, :2]
sd = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()

return Independent(Normal(mean, sd), 1)


Expand Down
17 changes: 11 additions & 6 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,21 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
return cls(height, width, d), psf_params, match_id

@classmethod
def get_bands_flux_and_psf(cls, bands, catalog):
def get_bands_flux_and_psf(cls, bands, catalog, median=True):
flux_list = []
psf_params_list = []
for b in bands:
flux_list.append(torch.from_numpy((catalog["flux_" + b]).values))
psf_params_name = ["IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_"]
psf_params_cur_band = []
for i in psf_params_name:
median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
psf_params_cur_band.append(median_psf)
psf_params_list.append(torch.tensor(psf_params_cur_band))

return torch.stack(flux_list).t(), torch.stack(psf_params_list)
if median:
median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
psf_params_cur_band.append(median_psf)
else:
psf_params_cur_band.append(catalog[i + b].values.astype(np.float32))
psf_params_list.append(
torch.tensor(psf_params_cur_band)
) # bands x 4 (params per band) x n_obj

return torch.stack(flux_list).t(), torch.stack(psf_params_list).unsqueeze(0)
47 changes: 33 additions & 14 deletions case_studies/weak_lensing/lensing_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ defaults:
mode: train

paths:

dc2: /data/scratch/dc2local
cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_shear02
output: /data/scratch/twhit/bliss_output
output: /data/scratch/shreyasc/bliss_output

prior:
_target_: case_studies.weak_lensing.lensing_prior.LensingPrior
Expand All @@ -34,20 +34,26 @@ cached_simulator:
train_transforms: []

variational_factors:
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
name: shear
- _target_: bliss.encoder.variational_dist.NormalFactor
name: shear_1
nll_gating: null
- _target_: bliss.encoder.variational_dist.NormalFactor
name: convergence
name: shear_2
nll_gating: null
high_clamp: 20.0
low_clamp: -20.0
# - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
# name: shear
# nll_gating: null
# - _target_: bliss.encoder.variational_dist.NormalFactor
# name: convergence
# nll_gating: null
# high_clamp: 20.0
# low_clamp: -20.0

my_normalizers:
# asinh:
# _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
# q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
# stride: 4
# sample_every_n: 4
nully:
_target_: bliss.encoder.image_normalizer.NullNormalizer

Expand All @@ -61,13 +67,15 @@ my_render:
frequency: 1
restrict_batch: 0
tile_slen: 256
save_local: "lensing_maps"
save_local: lensing_maps

encoder:
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
survey_bands: ["u", "g", "r", "i", "z", "y"]
reference_band: 2 # r-band
tile_slen: 256
n_tiles: 8
nch_hidden: 64
optimizer_params:
lr: 1e-3
scheduler_params:
Expand All @@ -93,7 +101,7 @@ encoder:
metrics: ${my_render}
use_double_detect: false
use_checkerboard: false
train_loss_location: "train_loss_plt"
train_loss_location: train_loss

surveys:
dc2:
Expand All @@ -108,8 +116,19 @@ surveys:
avg_ellip_kernel_sigma: 3
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits_img2048_tile256
cached_data_path: ${paths.output}/dc2_corrected_shear_only_cd_fix

generate:
n_image_files: 50
n_batches_per_file: 4
# generate:
# n_image_files: 50
# n_batches_per_file: 4
train:
trainer:
logger:
name: dc2_weak_lensing_exp
version: exp_09_16
devices: [0] # cuda:0 for gl
use_distributed_sampler: false
precision: 32-true
data_source: ${surveys.dc2}
pretrained_weights: null
seed: 123123
72 changes: 36 additions & 36 deletions case_studies/weak_lensing/lensing_convnet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import math

from torch import nn

from bliss.encoder.convnet_layers import C3, ConvBlock, Detect
from bliss.encoder.convnet_layers import Detect
from case_studies.weak_lensing.lensing_convnet_layers import RN2Block


class WeakLensingFeaturesNet(nn.Module):
def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
def __init__(self, n_bands, ch_per_band, num_features, tile_slen, nch_hidden):
super().__init__()

nch_hidden = 64
self.preprocess3d = nn.Sequential(
nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
nn.GroupNorm(
Expand All @@ -16,52 +18,50 @@ def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
nn.SiLU(),
)

# TODO: adaptive downsample
self.n_downsample = 1

module_list = []

for _ in range(self.n_downsample):
module_list.append(ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2))
nch_hidden *= 2
n_blocks2 = int(math.log2(num_features)) - int(math.log2(nch_hidden))
module_list = [RN2Block(nch_hidden, nch_hidden), RN2Block(nch_hidden, nch_hidden)]
for i in range(n_blocks2):
in_dim = nch_hidden * (2**i)
out_dim = in_dim * 2

module_list.extend(
[
ConvBlock(nch_hidden, 64, kernel_size=5),
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5) for _ in range(1)]),
ConvBlock(64, 128, stride=2),
nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
ConvBlock(128, num_features, stride=1),
]
) # 4
module_list.append(RN2Block(in_dim, out_dim, stride=2))
module_list.append(RN2Block(out_dim, out_dim))

self.net = nn.ModuleList(module_list)

def forward(self, x):
x = self.preprocess3d(x).squeeze(2)
for _i, m in enumerate(self.net):
x = m(x)

for _idx, layer in enumerate(self.net):
x = layer(x)
return x


class WeakLensingCatalogNet(nn.Module):
def __init__(self, in_channels, out_channels):
class WeakLensingCatalogNet(nn.Module): # TODO: get the dimensions down to n_tiles
def __init__(self, in_channels, out_channels, n_tiles):
super().__init__()

net_layers = [
C3(in_channels, 256, n=1, shortcut=True), # 0
ConvBlock(256, 512, stride=2),
C3(512, 256, n=1, shortcut=True), # true shortcut for skip connection
ConvBlock(
in_channels=256, out_channels=256, kernel_size=3, stride=8
), # (1, 256, 128, 128)
ConvBlock(in_channels=256, out_channels=256, kernel_size=3, stride=4), # (1, 256, 8, 8)
Detect(256, out_channels),
]
net_layers = []

n_blocks2 = int(math.log2(in_channels)) - int(math.ceil(math.log2(out_channels)))
last_out_dim = -1
for i in range(n_blocks2):
in_dim = in_channels // (2**i)
out_dim = in_dim // 2
if i < ((n_blocks2 + 4) // 2):
net_layers.append(RN2Block(in_dim, out_dim, stride=2))
else:
net_layers.append(RN2Block(in_dim, out_dim))
last_out_dim = out_dim

# Final detection layer to reduce channels
self.detect = Detect(last_out_dim, out_channels)
self.net = nn.ModuleList(net_layers)

def forward(self, x):
for _i, m in enumerate(self.net):
x = m(x)
return x

# Final detection layer
x = self.detect(x)

return x # noqa: WPS331
98 changes: 98 additions & 0 deletions case_studies/weak_lensing/lensing_convnet_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import math

from torch import nn


class RN2Block(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
)
out_c_sqrt = math.sqrt(out_channels)
if out_c_sqrt.is_integer():
n_groups = int(out_c_sqrt)
else:
n_groups = int(
math.sqrt(out_channels * 2)
) # even powers of 2 guaranteed to be perfect squares
self.gn1 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
self.silu = nn.SiLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.gn2 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
self.downsample = None
if stride != 1 or in_channels != out_channels:
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.GroupNorm(num_groups=n_groups, num_channels=out_channels),
)

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.gn1(out)
out = self.silu(out)

out = self.conv2(out)
out = self.gn2(out)

if self.downsample:
identity = self.downsample(x)

out += identity
out = self.silu(out)

return out # noqa: WPS331


class ResNeXtBlock(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, stride=1, groups=32):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
mid_c_sqrt = math.sqrt(mid_channels)
if mid_c_sqrt.is_integer():
mid_norm_n_groups = int(mid_c_sqrt)
else:
mid_norm_n_groups = int(
math.sqrt(mid_channels * 2)
) # even powers of 2 guaranteed to be perfect squares
self.gn1 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
self.conv2 = nn.Conv2d(
mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=groups
)
self.gn2 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
out_c_sqrt = math.sqrt(out_channels)
if out_c_sqrt.is_integer():
out_norm_n_groups = int(out_c_sqrt)
else:
out_norm_n_groups = int(
math.sqrt(out_channels * 2)
) # even powers of 2 guaranteed to be perfect squares
self.gn3 = nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels)
self.silu = nn.SiLU(inplace=True)

# Adjust the shortcut connection to match the output dimensions
self.shortcut = None
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
nn.GroupNorm(out_channels),
)

def forward(self, x):
residual = x
out = self.conv1(x)
out = self.gn1(out)
out = self.silu(out)
out = self.conv2(out)
out = self.gn2(out)
out = self.silu(out)
out = self.conv3(out)
out = self.gn3(out)
if self.shortcut:
residual = self.shortcut(x)
out += residual
out = self.silu(out)
return out # noqa: WPS331
Loading

0 comments on commit 6984255

Please sign in to comment.