From 6984255d6827621f5598d5fe69c16669d2b4bf23 Mon Sep 17 00:00:00 2001
From: Shreyas
Date: Wed, 18 Sep 2024 16:01:01 -0400
Subject: [PATCH] weak lensing encoder shear updates (#1072)

* Refactor generate_cached_data in lensing_dc2

* Decrease learning rate, remove clamp on convergence stdev

* Remove some print statements in lensing_encoder

* In-progress changes to the normalizer, convnet, and encoder, plus metrics and plots, to estimate shear only

* New architecture with ResNet and ResNeXt layers, plus preliminary changes to support PSF-as-image with the full PSF from a limited object table (no tests yet)

* Updated to make shear1 and shear2 separate normal factors

* Updated lensing config to split shear_1 and shear_2 into separate normal factors

* Updated network due to OOM

* Removed print statements from encoder

* Rolled back some debug changes and re-established consistency with master

* Style fixes for tests

* Style check updates

* Removed try/except from cached_dataset and fixed lensing_dc2

* Fixed lensing MSE denominator

* Fixed lensing config

---------

Co-authored-by: Tim White
Co-authored-by: Shreyas Chandrashekaran
Co-authored-by: Shreyas Chandrashekaran
---
 bliss/encoder/variational_dist.py             |  1 -
 bliss/surveys/dc2.py                          | 17 ++--
 case_studies/weak_lensing/lensing_config.yaml | 47 ++++++---
 case_studies/weak_lensing/lensing_convnet.py  | 72 +++++++-------
 .../weak_lensing/lensing_convnet_layers.py    | 98 +++++++++++++++++++
 case_studies/weak_lensing/lensing_dc2.py      | 58 +++++------
 case_studies/weak_lensing/lensing_encoder.py  | 26 ++---
 case_studies/weak_lensing/lensing_metrics.py  | 22 +++--
 case_studies/weak_lensing/lensing_plots.py    | 17 +++-
 9 files changed, 251 insertions(+), 107 deletions(-)
 create mode 100644 case_studies/weak_lensing/lensing_convnet_layers.py

diff --git a/bliss/encoder/variational_dist.py b/bliss/encoder/variational_dist.py
index fde146300..136dc7922 100644
--- a/bliss/encoder/variational_dist.py
+++ b/bliss/encoder/variational_dist.py
@@ -171,7 +171,6 @@ def __init__(self, *args, low_clamp=-20, high_clamp=20, **kwargs):
     def get_dist(self, params):
         mean = params[:, :, :, :2]
         sd = params[:, :, :, 2:].clamp(self.low_clamp, self.high_clamp).exp().sqrt()
-
         return Independent(Normal(mean, sd), 1)

diff --git a/bliss/surveys/dc2.py b/bliss/surveys/dc2.py
index bc4b31ab5..cd1aa1bd2 100644
--- a/bliss/surveys/dc2.py
+++ b/bliss/surveys/dc2.py
@@ -387,7 +387,7 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
         return cls(height, width, d), psf_params, match_id

     @classmethod
-    def get_bands_flux_and_psf(cls, bands, catalog):
+    def get_bands_flux_and_psf(cls, bands, catalog, median=True):
         flux_list = []
         psf_params_list = []
         for b in bands:
@@ -395,8 +395,13 @@ def get_bands_flux_and_psf(cls, bands, catalog):
             psf_params_name = ["IxxPSF_pixel_", "IyyPSF_pixel_", "IxyPSF_pixel_", "psf_fwhm_"]
             psf_params_cur_band = []
             for i in psf_params_name:
-                median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
-                psf_params_cur_band.append(median_psf)
-            psf_params_list.append(torch.tensor(psf_params_cur_band))
-
-        return torch.stack(flux_list).t(), torch.stack(psf_params_list)
+                if median:
+                    median_psf = np.nanmedian((catalog[i + b]).values).astype(np.float32)
+                    psf_params_cur_band.append(median_psf)
+                else:
+                    psf_params_cur_band.append(catalog[i + b].values.astype(np.float32))
+            psf_params_list.append(
+                torch.tensor(psf_params_cur_band)
+            )  # bands x 4 (params per band) x n_obj
+
+        return torch.stack(flux_list).t(), torch.stack(psf_params_list).unsqueeze(0)
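Note on the new median=False path: each band now contributes a (4, n_obj) tensor rather than four scalars, so the stacked result is (1, n_bands, 4, n_obj). A standalone shape walkthrough with dummy sizes (n_obj here is arbitrary, not from the patch):

import torch

n_obj = 5
bands = ["u", "g", "r", "i", "z", "y"]

# median=False: four per-object PSF param arrays per band -> (4, n_obj) each,
# stacked over bands and unsqueezed -> (1, n_bands, 4, n_obj)
psf_params = torch.stack([torch.randn(4, n_obj) for _ in bands]).unsqueeze(0)
assert psf_params.shape == (1, 6, 4, n_obj)

# lensing_dc2.from_file (below) then moves objects to dim 1 and flattens band/param dims:
flat = psf_params.permute(0, 3, 1, 2).flatten(2, -1)
assert flat.shape == (1, n_obj, 24)  # matches the "1, n_obj, 24" comment in lensing_dc2.py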
diff --git a/case_studies/weak_lensing/lensing_config.yaml b/case_studies/weak_lensing/lensing_config.yaml
index 775817124..635584bd5 100644
--- a/case_studies/weak_lensing/lensing_config.yaml
+++ b/case_studies/weak_lensing/lensing_config.yaml
@@ -7,9 +7,9 @@ defaults:
 mode: train

 paths:
+  dc2: /data/scratch/dc2local
-  cached_data: /data/scratch/weak_lensing/weak_lensing_img2048_shear02
-  output: /data/scratch/twhit/bliss_output
+  output: /data/scratch/shreyasc/bliss_output

 prior:
   _target_: case_studies.weak_lensing.lensing_prior.LensingPrior
@@ -34,20 +34,26 @@ cached_simulator:
   train_transforms: []

 variational_factors:
-  - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
-    name: shear
+  - _target_: bliss.encoder.variational_dist.NormalFactor
+    name: shear_1
     nll_gating: null
   - _target_: bliss.encoder.variational_dist.NormalFactor
-    name: convergence
+    name: shear_2
     nll_gating: null
-    high_clamp: 20.0
-    low_clamp: -20.0
+# - _target_: bliss.encoder.variational_dist.BivariateNormalFactor
+#   name: shear
+#   nll_gating: null
+# - _target_: bliss.encoder.variational_dist.NormalFactor
+#   name: convergence
+#   nll_gating: null
+#   high_clamp: 20.0
+#   low_clamp: -20.0

 my_normalizers:
 #  asinh:
 #    _target_: bliss.encoder.image_normalizer.AsinhQuantileNormalizer
 #    q: [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 0.999, 0.9999, 0.99999]
-#    stride: 4
+#    sample_every_n: 4
   nully:
     _target_: bliss.encoder.image_normalizer.NullNormalizer
@@ -61,13 +67,15 @@ my_render:
   frequency: 1
   restrict_batch: 0
   tile_slen: 256
-  save_local: "lensing_maps"
+  save_local: lensing_maps

 encoder:
   _target_: case_studies.weak_lensing.lensing_encoder.WeakLensingEncoder
   survey_bands: ["u", "g", "r", "i", "z", "y"]
   reference_band: 2 # r-band
   tile_slen: 256
+  n_tiles: 8
+  nch_hidden: 64
   optimizer_params:
     lr: 1e-3
   scheduler_params:
@@ -93,7 +101,7 @@ encoder:
   metrics: ${my_render}
   use_double_detect: false
   use_checkerboard: false
-  train_loss_location: "train_loss_plt"
+  train_loss_location: train_loss

 surveys:
   dc2:
@@ -108,8 +116,19 @@ surveys:
     avg_ellip_kernel_sigma: 3
     batch_size: 1
     num_workers: 1
-    cached_data_path: ${paths.dc2}/dc2_lensing_splits_img2048_tile256
+    cached_data_path: ${paths.output}/dc2_corrected_shear_only_cd_fix

-generate:
-  n_image_files: 50
-  n_batches_per_file: 4
+# generate:
+#   n_image_files: 50
+#   n_batches_per_file: 4
+
+train:
+  trainer:
+    logger:
+      name: dc2_weak_lensing_exp
+      version: exp_09_16
+    devices: [0] # cuda:0 for gl
+    use_distributed_sampler: false
+    precision: 32-true
+  data_source: ${surveys.dc2}
+  pretrained_weights: null
+  seed: 123123
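Note on the split variational factors: shear_1 and shear_2 are now two univariate NormalFactors instead of one BivariateNormalFactor. A minimal sketch of what such a factor does with its two raw outputs per tile, mirroring the clamp/exp/sqrt pattern in the variational_dist.py hunk above (the shapes and two-parameter layout are assumptions for illustration, not the exact BLISS API):

import torch
from torch.distributions import Normal

params = torch.randn(1, 8, 8, 2)  # per-tile raw network outputs: [mean, log-variance]
mean = params[..., :1]
sd = params[..., 1:].clamp(-20, 20).exp().sqrt()  # clamp keeps exp() finite
dist = Normal(mean, sd)
print(dist.log_prob(torch.zeros_like(mean)).shape)  # torch.Size([1, 8, 8, 1])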
diff --git a/case_studies/weak_lensing/lensing_convnet.py b/case_studies/weak_lensing/lensing_convnet.py
index f314366ea..7bf4ae930 100644
--- a/case_studies/weak_lensing/lensing_convnet.py
+++ b/case_studies/weak_lensing/lensing_convnet.py
@@ -1,13 +1,15 @@
+import math
+
 from torch import nn

-from bliss.encoder.convnet_layers import C3, ConvBlock, Detect
+from bliss.encoder.convnet_layers import Detect
+from case_studies.weak_lensing.lensing_convnet_layers import RN2Block


 class WeakLensingFeaturesNet(nn.Module):
-    def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
+    def __init__(self, n_bands, ch_per_band, num_features, tile_slen, nch_hidden):
         super().__init__()
-        nch_hidden = 64
         self.preprocess3d = nn.Sequential(
             nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]),
             nn.GroupNorm(
@@ -16,52 +18,50 @@ def __init__(self, n_bands, ch_per_band, num_features, tile_slen):
             nn.SiLU(),
         )

-        # TODO: adaptive downsample
-        self.n_downsample = 1
-
-        module_list = []
-
-        for _ in range(self.n_downsample):
-            module_list.append(ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2))
-            nch_hidden *= 2
+        n_blocks2 = int(math.log2(num_features)) - int(math.log2(nch_hidden))
+        module_list = [RN2Block(nch_hidden, nch_hidden), RN2Block(nch_hidden, nch_hidden)]
+        for i in range(n_blocks2):
+            in_dim = nch_hidden * (2**i)
+            out_dim = in_dim * 2
+            module_list.append(RN2Block(in_dim, out_dim, stride=2))
+            module_list.append(RN2Block(out_dim, out_dim))

-        module_list.extend(
-            [
-                ConvBlock(nch_hidden, 64, kernel_size=5),
-                nn.Sequential(*[ConvBlock(64, 64, kernel_size=5) for _ in range(1)]),
-                ConvBlock(64, 128, stride=2),
-                nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]),
-                ConvBlock(128, num_features, stride=1),
-            ]
-        )  # 4
         self.net = nn.ModuleList(module_list)

     def forward(self, x):
         x = self.preprocess3d(x).squeeze(2)
-        for _i, m in enumerate(self.net):
-            x = m(x)
-
+        for _idx, layer in enumerate(self.net):
+            x = layer(x)
         return x


-class WeakLensingCatalogNet(nn.Module):
-    def __init__(self, in_channels, out_channels):
+class WeakLensingCatalogNet(nn.Module):  # TODO: get the dimensions down to n_tiles
+    def __init__(self, in_channels, out_channels, n_tiles):
         super().__init__()
-        net_layers = [
-            C3(in_channels, 256, n=1, shortcut=True),  # 0
-            ConvBlock(256, 512, stride=2),
-            C3(512, 256, n=1, shortcut=True),  # true shortcut for skip connection
-            ConvBlock(
-                in_channels=256, out_channels=256, kernel_size=3, stride=8
-            ),  # (1, 256, 128, 128)
-            ConvBlock(in_channels=256, out_channels=256, kernel_size=3, stride=4),  # (1, 256, 8, 8)
-            Detect(256, out_channels),
-        ]
+        net_layers = []
+
+        n_blocks2 = int(math.log2(in_channels)) - int(math.ceil(math.log2(out_channels)))
+        last_out_dim = -1
+        for i in range(n_blocks2):
+            in_dim = in_channels // (2**i)
+            out_dim = in_dim // 2
+            if i < ((n_blocks2 + 4) // 2):
+                net_layers.append(RN2Block(in_dim, out_dim, stride=2))
+            else:
+                net_layers.append(RN2Block(in_dim, out_dim))
+            last_out_dim = out_dim
+
+        # Final detection layer to reduce channels
+        self.detect = Detect(last_out_dim, out_channels)
         self.net = nn.ModuleList(net_layers)

     def forward(self, x):
         for _i, m in enumerate(self.net):
             x = m(x)
-        return x
+
+        # Final detection layer
+        x = self.detect(x)
+
+        return x  # noqa: WPS331
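Note on the derived depths: the stage counts now follow from channel arithmetic instead of hard-coded layers. A quick check with the config values above (nch_hidden: 64, num_features 512); the 2048-pixel input and the 4 output parameters are assumptions, inferred from the img2048 naming and the two 2-parameter factors:

import math

# Feature net: 64 -> 512 channels takes 3 doublings, each paired with one stride-2
# block, so a 2048-pixel image exits at 2048 / 2**3 = 256 pixels and 512 channels.
n_stages = int(math.log2(512)) - int(math.log2(64))
assert n_stages == 3

# Catalog net: 512 -> 4 channels (two NormalFactors x 2 params) takes 7 halvings,
# the first (7 + 4) // 2 = 5 of them stride-2: 256 / 2**5 = 8, matching n_tiles: 8.
n_blocks = int(math.log2(512)) - int(math.ceil(math.log2(4)))
assert n_blocks == 7
assert 256 // 2 ** ((n_blocks + 4) // 2) == 8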
diff --git a/case_studies/weak_lensing/lensing_convnet_layers.py b/case_studies/weak_lensing/lensing_convnet_layers.py
new file mode 100644
index 000000000..dbeb29d24
--- /dev/null
+++ b/case_studies/weak_lensing/lensing_convnet_layers.py
@@ -0,0 +1,98 @@
+import math
+
+from torch import nn
+
+
+class RN2Block(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
+        )
+        out_c_sqrt = math.sqrt(out_channels)
+        if out_c_sqrt.is_integer():
+            n_groups = int(out_c_sqrt)
+        else:
+            n_groups = int(
+                math.sqrt(out_channels * 2)
+            )  # doubling an odd power of 2 gives an even power, i.e. a perfect square
+        self.gn1 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
+        self.silu = nn.SiLU(inplace=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
+        self.gn2 = nn.GroupNorm(num_groups=n_groups, num_channels=out_channels)
+        self.downsample = None
+        if stride != 1 or in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.GroupNorm(num_groups=n_groups, num_channels=out_channels),
+            )
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.gn1(out)
+        out = self.silu(out)
+
+        out = self.conv2(out)
+        out = self.gn2(out)
+
+        if self.downsample:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.silu(out)
+
+        return out  # noqa: WPS331
+
+
+class ResNeXtBlock(nn.Module):
+    def __init__(self, in_channels, mid_channels, out_channels, stride=1, groups=32):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
+        mid_c_sqrt = math.sqrt(mid_channels)
+        if mid_c_sqrt.is_integer():
+            mid_norm_n_groups = int(mid_c_sqrt)
+        else:
+            mid_norm_n_groups = int(
+                math.sqrt(mid_channels * 2)
+            )  # doubling an odd power of 2 gives an even power, i.e. a perfect square
+        self.gn1 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
+        self.conv2 = nn.Conv2d(
+            mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=groups
+        )
+        self.gn2 = nn.GroupNorm(num_groups=mid_norm_n_groups, num_channels=mid_channels)
+        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        out_c_sqrt = math.sqrt(out_channels)
+        if out_c_sqrt.is_integer():
+            out_norm_n_groups = int(out_c_sqrt)
+        else:
+            out_norm_n_groups = int(
+                math.sqrt(out_channels * 2)
+            )  # doubling an odd power of 2 gives an even power, i.e. a perfect square
+        self.gn3 = nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels)
+        self.silu = nn.SiLU(inplace=True)
+
+        # Adjust the shortcut connection to match the output dimensions
+        self.shortcut = None
+        if stride != 1 or in_channels != out_channels:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0),
+                nn.GroupNorm(num_groups=out_norm_n_groups, num_channels=out_channels),
+            )
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.gn1(out)
+        out = self.silu(out)
+        out = self.conv2(out)
+        out = self.gn2(out)
+        out = self.silu(out)
+        out = self.conv3(out)
+        out = self.gn3(out)
+        if self.shortcut:
+            residual = self.shortcut(x)
+        out += residual
+        out = self.silu(out)
+        return out  # noqa: WPS331
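Note on the GroupNorm group counts: the blocks use sqrt(C) groups when C is a perfect square, else sqrt(2 * C). A standalone check that the chosen group count always divides the channel count for the widths used in this PR:

import math

def n_groups(channels):
    root = math.sqrt(channels)
    if root.is_integer():
        return int(root)
    return int(math.sqrt(channels * 2))  # 2 * an odd power of 2 is a perfect square

for c in (64, 128, 256, 512):
    g = n_groups(c)
    assert c % g == 0  # nn.GroupNorm requires num_channels divisible by num_groups
    print(c, g)  # 64 -> 8, 128 -> 16, 256 -> 16, 512 -> 32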
diff --git a/case_studies/weak_lensing/lensing_dc2.py b/case_studies/weak_lensing/lensing_dc2.py
index d92282d17..ecd76ad37 100644
--- a/case_studies/weak_lensing/lensing_dc2.py
+++ b/case_studies/weak_lensing/lensing_dc2.py
@@ -30,7 +30,6 @@ def __init__(
         batch_size: int,
         num_workers: int,
         cached_data_path: str,
-        mag_max_cut: float = None,
         **kwargs,
     ):
         super().__init__(
@@ -56,7 +55,6 @@ def __init__(
         self.image_slen = image_slen
         self.bands = self.BANDS
         self.n_bands = len(self.BANDS)
-        self.mag_max_cut = mag_max_cut
         self.avg_ellip_kernel_size = avg_ellip_kernel_size
         self.avg_ellip_kernel_sigma = avg_ellip_kernel_sigma

@@ -84,7 +82,7 @@ def to_tile_catalog(self, full_catalog, height, width):
         n_tiles_h = math.ceil(height / self.tile_slen)
         n_tiles_w = math.ceil(width / self.tile_slen)
         stti = source_tile_coords[:, :, 0] * n_tiles_w + source_tile_coords[:, :, 1]
-        source_to_tile_indices = stti.unsqueeze(-1).to(dtype=torch.int64)
+        source_to_tile_indices_default = stti.unsqueeze(-1).to(dtype=torch.int64)

         tile_cat = {}

@@ -94,23 +92,27 @@
             if k == "plocs":
                 continue

-            v = v.reshape(self.batch_size, plocs.shape[1], 1)
+            v = v.reshape(self.batch_size, plocs.shape[1], -1)

             if k == "mag_mask":
                 continue

-            v_sum = torch.zeros(self.batch_size, num_tiles, 1, dtype=v.dtype)
-            v_count = torch.zeros(self.batch_size, num_tiles, 1, dtype=v.dtype)
+            v_sum = torch.zeros(self.batch_size, num_tiles, v.shape[-1], dtype=v.dtype)
+            v_count = torch.zeros(self.batch_size, num_tiles, v.shape[-1], dtype=v.dtype)
             add_pos = torch.ones_like(v)
-            if k in {"ellip1_lensed", "ellip2_lensed"}:
-                mag_mask = full_catalog["magnitude_cut_mask"]
-                v = torch.where(mag_mask, v, torch.tensor(0.0))
-                add_pos = mag_mask.float().to(v.dtype)
-
+            if k == "psf":
+                psf_nanmask = ~torch.isnan(v).any(dim=-1, keepdim=True).expand(-1, -1, v.shape[-1])
+                v = torch.where(psf_nanmask, v, torch.tensor(0.0))
+                add_pos = psf_nanmask.float().to(v.dtype)
+                source_to_tile_indices = source_to_tile_indices_default.expand(-1, -1, v.shape[-1])
+            else:
+                source_to_tile_indices = source_to_tile_indices_default
             v_sum = v_sum.scatter_add(1, source_to_tile_indices, v)
             v_count = v_count.scatter_add(1, source_to_tile_indices, add_pos)

-            tile_cat[k + "_sum"] = v_sum.reshape(self.batch_size, n_tiles_w, n_tiles_h, 1)
-            tile_cat[k + "_count"] = v_count.reshape(self.batch_size, n_tiles_w, n_tiles_h, 1)
+            tile_cat[k + "_sum"] = v_sum.reshape(self.batch_size, n_tiles_w, n_tiles_h, v.shape[-1])
+            tile_cat[k + "_count"] = v_count.reshape(
+                self.batch_size, n_tiles_w, n_tiles_h, v.shape[-1]
+            )

         return BaseTileCatalog(tile_cat)

     # override load_image_and_catalog
@@ -126,12 +128,14 @@
             wcs,
             height,
             width,
-            mag_max_cut=self.mag_max_cut,
             bands=self.bands,
             n_bands=self.n_bands,
         )

         tile_cat = self.to_tile_catalog(full_cat, height, width)
+        psf_params = tile_cat["psf_sum"] / tile_cat["psf_count"]
+        del tile_cat["psf_sum"]
+        del tile_cat["psf_count"]
         tile_dict = self.squeeze_tile_dict(tile_cat.data)

         return {
@@ -157,14 +161,14 @@ def generate_cached_data(self, image_index):
         shear1 = tile_dict["shear1_sum"] / tile_dict["shear1_count"]
         shear2 = tile_dict["shear2_sum"] / tile_dict["shear2_count"]
-        shear = torch.stack((shear1.squeeze(-1), shear2.squeeze(-1)), dim=-1)
         convergence = tile_dict["convergence_sum"] / tile_dict["convergence_count"]
         ellip1_lensed = tile_dict["ellip1_lensed_sum"] / tile_dict["ellip1_lensed_count"]
         ellip2_lensed = tile_dict["ellip2_lensed_sum"] / tile_dict["ellip2_lensed_count"]
         ellip_lensed = torch.stack((ellip1_lensed.squeeze(-1), ellip2_lensed.squeeze(-1)), dim=-1)
         redshift = tile_dict["redshift_sum"] / tile_dict["redshift_count"]

-        tile_dict["shear"] = shear
+        tile_dict["shear_1"] = shear1
+        tile_dict["shear_2"] = shear2
         tile_dict["convergence"] = convergence
         tile_dict["ellip_lensed"] = ellip_lensed
         tile_dict["ellip_lensed_wavg"] = compute_weighted_avg_ellip(
@@ -177,18 +181,18 @@
         data_to_cache = unpack_dict(data_splits)

         for i in range(self.n_image_split**2):  # noqa: WPS426
-            cached_data_file_name = f"cached_data_{image_index:04d}_{i:04d}.pt"
+            cached_data_file_name = f"cached_data_{image_index:04d}_{i:04d}_size_1.pt"
             tmp = data_to_cache[i]
             tmp_clone = map_nested_dicts(
                 tmp, lambda x: x.clone() if isinstance(x, torch.Tensor) else x
             )
             with open(self.cached_data_path / cached_data_file_name, "wb") as cached_data_file:
-                torch.save(tmp_clone, cached_data_file)
+                torch.save([tmp_clone], cached_data_file)


 class LensingDC2Catalog(DC2FullCatalog):
     @classmethod
-    def from_file(cls, cat_path, wcs, height, width, mag_max_cut=None, **kwargs):
+    def from_file(cls, cat_path, wcs, height, width, **kwargs):
         catalog = pd.read_pickle(cat_path)

         galid = torch.from_numpy(catalog["galaxy_id"].values)
@@ -212,13 +216,8 @@ def from_file(cls, cat_path, wcs, height, width, **kwargs):
         redshift = torch.from_numpy(catalog["redshift"].values)

-        if mag_max_cut:
-            mag_r = torch.from_numpy(catalog["mag_r"].values)
-            mag_mask = mag_r < mag_max_cut
-        else:
-            mag_mask = torch.ones_like(galid).bool()
-
-        _, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog)
+        _, psf_params = cls.get_bands_flux_and_psf(kwargs["bands"], catalog, median=False)
+        # psf_params is n_bands x 4 (n_params) x n_measures

         plocs = cls.plocs_from_ra_dec(ra, dec, wcs).squeeze(0)
         x0_mask = (plocs[:, 0] > 0) & (plocs[:, 0] < height)
@@ -236,10 +235,11 @@
         redshift = redshift[plocs_mask]

-        mag_mask = mag_mask[plocs_mask]
+        psf_params = psf_params[:, :, :, plocs_mask.squeeze() == 1]
+        psf_params = psf_params.permute(0, 3, 1, 2).flatten(2, -1)  # 1, n_obj, 24

         nobj = galid.shape[0]
-        # TODO: pass existant shear & convergence masks in d
+
         d = {
             "plocs": plocs.reshape(1, nobj, 2),
             "shear1": shear1.reshape(1, nobj, 1),
             "shear2": shear2.reshape(1, nobj, 1),
             "convergence": convergence.reshape(1, nobj, 1),
             "ellip1_lensed": ellip1_lensed.reshape(1, nobj, 1),
             "ellip2_lensed": ellip2_lensed.reshape(1, nobj, 1),
-            "magnitude_cut_mask": mag_mask.reshape(1, nobj, 1),
             "redshift": redshift.reshape(1, nobj, 1),
+            "psf": psf_params.reshape(1, nobj, -1),
         }

         return cls(height, width, d), psf_params
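Note on the tile averaging: per-source values are scatter-added into per-tile sums and counts, then divided downstream (e.g., shear1_sum / shear1_count). A self-contained toy version of that pattern, without the NaN masking used for the psf entry:

import torch

batch, n_sources, n_tiles, dim = 1, 6, 4, 1
v = torch.arange(n_sources, dtype=torch.float32).reshape(batch, n_sources, dim)
tile_idx = torch.tensor([[0, 0, 1, 3, 3, 3]]).unsqueeze(-1)  # tile of each source

v_sum = torch.zeros(batch, n_tiles, dim).scatter_add(1, tile_idx, v)
v_count = torch.zeros(batch, n_tiles, dim).scatter_add(1, tile_idx, torch.ones_like(v))
print((v_sum / v_count).squeeze())  # tensor([0.5000, 2.0000, nan, 4.0000]); empty tiles give nan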
diff --git a/case_studies/weak_lensing/lensing_encoder.py b/case_studies/weak_lensing/lensing_encoder.py
index 02bbc054c..d5a49a076 100644
--- a/case_studies/weak_lensing/lensing_encoder.py
+++ b/case_studies/weak_lensing/lensing_encoder.py
@@ -15,6 +15,8 @@ def __init__(
         self,
         survey_bands: list,
         tile_slen: int,
+        n_tiles: int,
+        nch_hidden: int,
         image_normalizers: list,
         var_dist: VariationalDist,
         sample_image_renders: MetricCollection,
@@ -25,6 +27,9 @@ def __init__(
         reference_band: int = 2,
         **kwargs,
     ):
+        self.n_tiles = n_tiles
+        self.nch_hidden = nch_hidden
+
         super().__init__(
             survey_bands=survey_bands,
             tile_slen=tile_slen,
@@ -50,17 +55,20 @@ def __init__(

     # override
     def initialize_networks(self):
-        num_features = 256
+        num_features = 512
         ch_per_band = sum(inorm.num_channels_per_band() for inorm in self.image_normalizers)
         self.features_net = WeakLensingFeaturesNet(
             n_bands=len(self.survey_bands),
             ch_per_band=ch_per_band,
             num_features=num_features,
             tile_slen=self.tile_slen,
+            nch_hidden=self.nch_hidden,
         )
+
         self.catalog_net = WeakLensingCatalogNet(
             in_channels=num_features,
             out_channels=self.var_dist.n_params_per_source,
+            n_tiles=self.n_tiles,
         )

     def sample(self, batch, use_mode=True):
@@ -73,6 +81,10 @@ def sample(self, batch, use_mode=True):
         # est cat
         return self.var_dist.sample(x_cat_marginal, use_mode=use_mode, return_base_cat=True)

+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        with torch.no_grad():
+            return self.sample(batch, use_mode=True)
+
     def _compute_loss(self, batch, logging_name):
         batch_size, _, _, _ = batch["images"].shape[0:4]

@@ -104,16 +116,8 @@ def on_after_backward(self):
                 param_grad_norm = param.grad.data.norm(2).item()
                 total_grad_norm += param_grad_norm**2
         total_grad_norm = total_grad_norm**0.5
-
-    def configure_gradient_clipping(
-        self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None
-    ):
-        clip_min = 1e-4
-        clip_max = 10
-        for _name, param in self.named_parameters():
-            if param.grad is not None:
-                with torch.no_grad():
-                    param.grad.data.clamp_(min=clip_min, max=clip_max)
+        if total_grad_norm > 100 or total_grad_norm < 1e-4:
+            print("total_grad_norm", total_grad_norm)  # noqa: WPS421

     def update_metrics(self, batch, batch_idx):
         target_cat = BaseTileCatalog(batch["tile_catalog"])
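Note on the removed configure_gradient_clipping: clamping every gradient element into [1e-4, 10] also forced all gradients to be non-negative, so the PR drops clipping and only logs unusual norms. A standalone sketch of the check on_after_backward now performs (the toy model is illustrative, not from the repo):

import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).pow(2).mean().backward()

total = sum(p.grad.data.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None)
total_grad_norm = total**0.5
if total_grad_norm > 100 or total_grad_norm < 1e-4:
    print("total_grad_norm", total_grad_norm)  # flags exploding/vanishing gradients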
print("total_grad_norm", total_grad_norm) # noqa: WPS421 def update_metrics(self, batch, batch_idx): target_cat = BaseTileCatalog(batch["tile_catalog"]) diff --git a/case_studies/weak_lensing/lensing_metrics.py b/case_studies/weak_lensing/lensing_metrics.py index a725e9d65..9f09e066d 100644 --- a/case_studies/weak_lensing/lensing_metrics.py +++ b/case_studies/weak_lensing/lensing_metrics.py @@ -15,14 +15,24 @@ def __init__(self, **kwargs): ) self.add_state("convergence_sum_squared_err", default=torch.zeros(1), dist_reduce_fx="sum") # potentially throws a division by zero error if true_idx is empty and uncaught - self.total = 1 + self.total = torch.tensor(0) def update(self, true_cat, est_cat, matching) -> None: - true_shear = true_cat["shear"].flatten(1, 2) - pred_shear = est_cat["shear"].flatten(1, 2) + true_shear1 = true_cat["shear_1"] + true_shear2 = true_cat["shear_2"] + pred_shear1 = est_cat["shear_1"] + pred_shear2 = est_cat["shear_2"] + true_shear = torch.cat((true_shear1, true_shear2), dim=-1) + pred_shear = torch.cat((pred_shear1, pred_shear2), dim=-1) + true_shear = true_shear.flatten(1, 2) + pred_shear = pred_shear.flatten(1, 2) baseline_pred_shear = true_cat["ellip_lensed"].flatten(1, 2) - true_convergence = true_cat["convergence"].flatten(1, 2) - pred_convergence = est_cat["convergence"].flatten(1, 2) + if "convergence" not in est_cat: + true_convergence = torch.zeros_like(true_shear1).flatten(1, 2) + pred_convergence = torch.zeros_like(true_convergence).flatten(1, 2) + else: + true_convergence = true_cat["convergence"].flatten(1, 2) + pred_convergence = est_cat["convergence"].flatten(1, 2) shear1_sq_err = ((true_shear[:, :, 0] - pred_shear[:, :, 0]) ** 2).sum() baseline_shear1_sq_err = ((true_shear[:, :, 0] - baseline_pred_shear[:, :, 0]) ** 2).sum() @@ -36,7 +46,7 @@ def update(self, true_cat, est_cat, matching) -> None: self.baseline_shear2_sum_squared_err += baseline_shear2_sq_err self.convergence_sum_squared_err += convergence_sq_err - self.total = torch.tensor(true_convergence.shape[1]) + self.total += torch.tensor(true_convergence.shape[1]) def compute(self): shear1_mse = self.shear1_sum_squared_err / self.total diff --git a/case_studies/weak_lensing/lensing_plots.py b/case_studies/weak_lensing/lensing_plots.py index 4f36cdd08..2a1ce9e36 100644 --- a/case_studies/weak_lensing/lensing_plots.py +++ b/case_studies/weak_lensing/lensing_plots.py @@ -83,10 +83,19 @@ def plot_maps(images, true_tile_cat, est_tile_cat, figsize=None, current_epoch=0 figsize = (20, 20) fig, axes = plt.subplots(nrows=num_images, ncols=num_lensing_params, figsize=figsize) - true_shear = true_tile_cat["shear"] - est_shear = est_tile_cat["shear"] - true_convergence = true_tile_cat["convergence"] - est_convergence = est_tile_cat["convergence"] + true_shear1 = true_tile_cat["shear_1"] + true_shear2 = true_tile_cat["shear_2"] + pred_shear1 = est_tile_cat["shear_1"] + pred_shear2 = est_tile_cat["shear_2"] + true_shear = torch.cat((true_shear1, true_shear2), dim=-1) + est_shear = torch.cat((pred_shear1, pred_shear2), dim=-1) + + if "convergence" not in est_tile_cat: + true_convergence = torch.zeros_like(true_shear1) + est_convergence = torch.zeros_like(true_convergence) + else: + true_convergence = true_tile_cat["convergence"] + est_convergence = est_tile_cat["convergence"] for img_id in img_ids: shear1_vmin = torch.min(true_shear[img_id].squeeze()[:, :, 0])