Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

weak lensing encoder shear updates #1072

Merged
merged 19 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@
self.buffered_file_index = converted_index
with open(self.file_paths[converted_index], "rb") as f:
self.buffered_data = torch.load(f)
output_data = self.buffered_data[converted_sub_index]
try:
output_data = self.buffered_data[converted_sub_index]
except KeyError:
output_data = self.buffered_data

Check warning on line 184 in bliss/cached_dataset.py

View check run for this annotation

Codecov / codecov/patch

bliss/cached_dataset.py#L183-L184

Added lines #L183 - L184 were not covered by tests
return self.transform(output_data)

def get_chunked_indices(self):
Expand Down
1 change: 0 additions & 1 deletion bliss/encoder/variational_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
def get_dist(self, params):
mean = params[:, :, :, :2]
sd = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()

return Independent(Normal(mean, sd), 1)


Expand Down
17 changes: 11 additions & 6 deletions bliss/surveys/dc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,21 @@
return cls(height, width, d), psf_params, match_id

@classmethod
def get_bands_flux_and_psf(cls, bands, catalog):
def get_bands_flux_and_psf(cls, bands, catalog, median=True):
flux_list = []
psf_params_list = []
for b in bands:
flux_list.append(torch.from_numpy((catalog["flux_" + b]).values))
psf_params_name = ["IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_"]
psf_params_cur_band = []
for i in psf_params_name:
median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
psf_params_cur_band.append(median_psf)
psf_params_list.append(torch.tensor(psf_params_cur_band))

return torch.stack(flux_list).t(), torch.stack(psf_params_list)
if median:
median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
psf_params_cur_band.append(median_psf)
else:
psf_params_cur_band.append(catalog[i + b].values.astype(np.float32))

Check warning on line 402 in bliss/surveys/dc2.py

View check run for this annotation

Codecov / codecov/patch

bliss/surveys/dc2.py#L402

Added line #L402 was not covered by tests
psf_params_list.append(
torch.tensor(psf_params_cur_band)
) # bands x 4 (params per band) x n_obj

return torch.stack(flux_list).t(), torch.stack(psf_params_list).unsqueeze(0)
39 changes: 30 additions & 9 deletions case_studies/weak_lensing/lensing_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ defaults:
mode: train

paths:
dc2: /scratch/regier_root/regier0/shreyasc/data/dc2local # change for gl
output: /scratch/regier_root/regier0/shreyasc/data/bliss_output # change for gl
dc2: /data/scratch/dc2local
cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_shear02
output: /data/scratch/twhit/bliss_output
Expand Down Expand Up @@ -34,20 +36,26 @@ cached_simulator:
train_transforms: []

variational_factors:
- _target_: bliss.encoder.variational_dist.BivariateNormalFactor
name: shear
- _target_: bliss.encoder.variational_dist.NormalFactor
name: shear_1
nll_gating: null
- _target_: bliss.encoder.variational_dist.NormalFactor
name: convergence
name: shear_2
nll_gating: null
high_clamp: 20.0
low_clamp: -20.0
# - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
# name: shear
# nll_gating: null
# - _target_: bliss.encoder.variational_dist.NormalFactor
# name: convergence
# nll_gating: null
# high_clamp: 20.0
# low_clamp: -20.0

my_normalizers:
# asinh:
# _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
# q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
# stride: 4
# sample_every_n: 4
nully:
_target_: bliss.encoder.image_normalizer.NullNormalizer

Expand All @@ -61,13 +69,15 @@ my_render:
frequency: 1
restrict_batch: 0
tile_slen: 256
save_local: "lensing_maps"
save_local: lensing_maps

encoder:
_target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
survey_bands: ["u", "g", "r", "i", "z", "y"]
reference_band: 2 # r-band
tile_slen: 256
n_tiles: 8
nch_hidden: 64
optimizer_params:
lr: 1e-3
scheduler_params:
Expand All @@ -93,7 +103,7 @@ encoder:
metrics: ${my_render}
use_double_detect: false
use_checkerboard: false
train_loss_location: "train_loss_plt"
train_loss_location: train_loss

surveys:
dc2:
Expand All @@ -108,8 +118,19 @@ surveys:
avg_ellip_kernel_sigma: 3
batch_size: 1
num_workers: 1
cached_data_path: ${paths.dc2}/dc2_lensing_splits_img2048_tile256
cached_data_path: ${paths.output}/dc2_corrected_shear_only

generate:
n_image_files: 50
n_batches_per_file: 4
train:
trainer:
logger:
name: dc2_weak_lensing_exp
version: exp_09_16
devices: [0] # cuda:0 for gl
use_distributed_sampler: false
precision: 32-true
data_source: ${surveys.dc2}
pretrained_weights: null
seed: 123123
72 changes: 36 additions & 36 deletions case_studies/weak_lensing/lensing_convnet.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import math

from torch import nn

from bliss.encoder.convnet_layers import C3, ConvBlock, Detect
from bliss.encoder.convnet_layers import Detect
from case_studies.weak_lensing.lensing_convnet_layers import RN2Block


class WeakLensingFeaturesNet(nn.Module):
def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
def __init__(self, n_bands, ch_per_band, num_features, tile_slen, nch_hidden):
super().__init__()

nch_hidden = 64
self.preprocess3d = nn.Sequential(
nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
nn.GroupNorm(
Expand All @@ -16,52 +18,50 @@ def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
nn.SiLU(),
)

# TODO: adaptive downsample
self.n_downsample = 1

module_list = []

for _ in range(self.n_downsample):
module_list.append(ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2))
nch_hidden *= 2
n_blocks2 = int(math.log2(num_features)) - int(math.log2(nch_hidden))
module_list = [RN2Block(nch_hidden, nch_hidden), RN2Block(nch_hidden, nch_hidden)]
for i in range(n_blocks2):
in_dim = nch_hidden * (2**i)
out_dim = in_dim * 2

module_list.extend(
[
ConvBlock(nch_hidden, 64, kernel_size=5),
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5) for _ in range(1)]),
ConvBlock(64, 128, stride=2),
nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
ConvBlock(128, num_features, stride=1),
]
) # 4
module_list.append(RN2Block(in_dim, out_dim, stride=2))
module_list.append(RN2Block(out_dim, out_dim))

self.net = nn.ModuleList(module_list)

def forward(self, x):
x = self.preprocess3d(x).squeeze(2)
for _i, m in enumerate(self.net):
x = m(x)

for _idx, layer in enumerate(self.net):
x = layer(x)
return x


class WeakLensingCatalogNet(nn.Module):
def __init__(self, in_channels, out_channels):
class WeakLensingCatalogNet(nn.Module): # TODO: get the dimensions down to n_tiles
def __init__(self, in_channels, out_channels, n_tiles):
super().__init__()

net_layers = [
C3(in_channels, 256, n=1, shortcut=True), # 0
ConvBlock(256, 512, stride=2),
C3(512, 256, n=1, shortcut=True), # true shortcut for skip connection
ConvBlock(
in_channels=256, out_channels=256, kernel_size=3, stride=8
), # (1, 256, 128, 128)
ConvBlock(in_channels=256, out_channels=256, kernel_size=3, stride=4), # (1, 256, 8, 8)
Detect(256, out_channels),
]
net_layers = []

n_blocks2 = int(math.log2(in_channels)) - int(math.ceil(math.log2(out_channels)))
last_out_dim = -1
for i in range(n_blocks2):
in_dim = in_channels // (2**i)
out_dim = in_dim // 2
if i < ((n_blocks2 + 4) // 2):
net_layers.append(RN2Block(in_dim, out_dim, stride=2))
else:
net_layers.append(RN2Block(in_dim, out_dim))
last_out_dim = out_dim

# Final detection layer to reduce channels
self.detect = Detect(last_out_dim, out_channels)
self.net = nn.ModuleList(net_layers)

def forward(self, x):
for _i, m in enumerate(self.net):
x = m(x)
return x

# Final detection layer
x = self.detect(x)

return x # noqa: WPS331
98 changes: 98 additions & 0 deletions case_studies/weak_lensing/lensing_convnet_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import math

from torch import nn


class RN2Block(nn.Module):
    """Residual block: two 3x3 convolutions, each followed by GroupNorm.

    The main path is ``conv1 -> gn1 -> SiLU -> conv2 -> gn2``.  Its output is
    added to the (possibly projected) input and passed through a final SiLU.
    ``conv1`` carries the stride; ``conv2`` always uses stride 1.  Both use
    ``padding=1`` and no bias.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
        stride: stride of the first convolution and of the projection
            shortcut (default 1).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        n_groups = self._num_groups(out_channels)

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and seeded parameter initialisation are unchanged.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.gn1 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
        self.silu = nn.SiLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.gn2 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)

        # A 1x1 projection is only needed when the identity path would not
        # match the main path in channel count or spatial size.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(num_groups=n_groups, num_channels=out_channels),
            )

    @staticmethod
    def _num_groups(num_channels):
        """Return a GroupNorm group count that divides ``num_channels``.

        The preferred value is ``sqrt(num_channels)`` when that is an integer,
        otherwise ``int(sqrt(2 * num_channels))``.  For powers of two one of
        ``C`` and ``2 * C`` is always a perfect square, so the preferred value
        divides ``C``.  For any other channel count the value is lowered to
        the nearest divisor, because GroupNorm rejects a group count that does
        not divide the channel count.
        """
        root = math.sqrt(num_channels)
        groups = int(root) if root.is_integer() else int(math.sqrt(num_channels * 2))
        while groups > 1 and num_channels % groups != 0:
            groups -= 1
        return groups

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        identity = x

        out = self.conv1(x)
        out = self.gn1(out)
        out = self.silu(out)

        out = self.conv2(out)
        out = self.gn2(out)

        # Explicit None check: truthiness of an nn.Sequential is its length.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.silu(out)

        return out  # noqa: WPS331


class ResNeXtBlock(nn.Module):
    """Bottleneck residual block with a grouped 3x3 convolution.

    The main path is ``1x1 conv -> GroupNorm -> SiLU -> grouped 3x3 conv
    (carries the stride) -> GroupNorm -> SiLU -> 1x1 conv -> GroupNorm``.  Its
    output is added to the (possibly projected) input and passed through a
    final SiLU.

    Args:
        in_channels: number of channels of the input tensor.
        mid_channels: number of channels inside the bottleneck; ``nn.Conv2d``
            requires it to be divisible by ``groups``.
        out_channels: number of channels produced by the block.
        stride: stride of the 3x3 convolution and of the projection shortcut
            (default 1).
        groups: ``groups`` argument of the 3x3 convolution (default 32).
    """

    def __init__(self, in_channels, mid_channels, out_channels, stride=1, groups=32):
        super().__init__()
        mid_norm_n_groups = self._num_groups(mid_channels)
        out_norm_n_groups = self._num_groups(out_channels)

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and seeded parameter initialisation are unchanged.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
        self.gn1 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=groups
        )
        self.gn2 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.gn3 = nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels)
        self.silu = nn.SiLU(inplace=True)

        # Adjust the shortcut connection to match the output dimensions.
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
                # GroupNorm needs both the group count and the channel count;
                # passing only one positional argument raises TypeError.
                nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels),
            )

    @staticmethod
    def _num_groups(num_channels):
        """Return a GroupNorm group count that divides ``num_channels``.

        The preferred value is ``sqrt(num_channels)`` when that is an integer,
        otherwise ``int(sqrt(2 * num_channels))``.  For powers of two one of
        ``C`` and ``2 * C`` is always a perfect square, so the preferred value
        divides ``C``.  For any other channel count the value is lowered to
        the nearest divisor, because GroupNorm rejects a group count that does
        not divide the channel count.
        """
        root = math.sqrt(num_channels)
        n_groups = int(root) if root.is_integer() else int(math.sqrt(num_channels * 2))
        while n_groups > 1 and num_channels % n_groups != 0:
            n_groups -= 1
        return n_groups

    def forward(self, x):
        """Apply the bottleneck block to ``x`` and return the activated sum."""
        residual = x

        out = self.conv1(x)
        out = self.gn1(out)
        out = self.silu(out)

        out = self.conv2(out)
        out = self.gn2(out)
        out = self.silu(out)

        out = self.conv3(out)
        out = self.gn3(out)

        # Explicit None check: truthiness of an nn.Sequential is its length.
        if self.shortcut is not None:
            residual = self.shortcut(x)

        out += residual
        out = self.silu(out)
        return out  # noqa: WPS331
Loading
Loading