-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added changes for plotting extraction to case_studies, plotcollection class, updated paths and configs * added preliminary dc2 data exploration (week 1 and 2) * attempted refactor of plots to use metrics instead of plotcollection, modified yaml files to support change, added null check before plotting in encoder * updated base_config to reflect correct path to plots * Refactor lensing plots * updated encoder and plots to reflect req changes * updated name of sample_image_renders to be more descriptive re function * renamed plots to sample_image_renders and updated yaml to reflect * updated encoder and sample_images to add support for weak_lensing plots (minor changes), refactored weak_lensing plots (major change) * update plot_detections path in region_encoder.py * Delete case_studies/dc2/dc2_shear_conv.ipynb (refactor and move to weak_lensing later on) * Fix list/dict error in lensing config * updated lensing config to match modifications * initial commit of unfinished lensing map EDA nb * updated weak lensing EDA notebook * updated sample_image_renders to match changes in PR#1007 * Move DC2 lensing maps to notebooks dir * Notebook with two-point corr for DC2 * Updated two-point notebook * added dc2 datagen notebook for weak lensing * Use DC2 redshift information in two-point notebook * partial changes and restored full dc2, up to date with master; testing encoder and lensing_dc2 * dc2 for weak lensing naive architecture with subclassed dc2, convnet, encoder, modified metrics, have to make some changes to pass commit tests * updated convnet and encoder to match smaller image size * attempted fix of nans in backward pass * fixed lensing dc2 to remove split files, made corresponding changes in cached dataset, updated lensing convnet to restore to old features net and fixed lensing plots * save plots changes * updated to 1 downsampling * added scale matching for plots * updated lensing_config * updated lensing config * images without background * Remove unnecessary .view in to_full_catalog * Update base_config.yaml * Update base_config.yaml * pylint fix for lensing metrics * Fix shear/convergence shapes in lensing MSE * Add image-level means as baseline estimator for shear and convergence * fixed total in lensing metrics to pass check * updated encoder and vardist to roll back tilecatalog changes, replace with basetilecatalog * removed duplicate methods in lensing encoder * restored base_config from master * updated lensing_dc2 to factor in changes to dc2 * added simple null normalizer to image normalizer (will remove in subsequent commit after fixing asinh), updated lensing config paths, and modified masking operation in lensing_dc2 * updated files to reflect recent changes and PR comments * removed try/catch * updated lensing_dc2 to match recent changes to dc2 --------- Co-authored-by: Tim White <[email protected]> Co-authored-by: Shreyas Chandrashekaran <[email protected]>
- Loading branch information
1 parent
a77531b
commit e857c6f
Showing
11 changed files
with
2,283 additions
and
548 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from torch import nn | ||
|
||
from bliss.encoder.convnet_layers import C3, ConvBlock, Detect | ||
|
||
|
||
class WeakLensingFeaturesNet(nn.Module): | ||
def __init__(self, n_bands, ch_per_band, num_features, tile_slen): | ||
super().__init__() | ||
|
||
nch_hidden = 64 | ||
self.preprocess3d = nn.Sequential( | ||
nn.Conv3d(n_bands, nch_hidden, [ch_per_band, 5, 5], padding=[0, 2, 2]), | ||
nn.GroupNorm( | ||
num_groups=32, num_channels=nch_hidden | ||
), # sqrt of num channels, get rid of it, even shallower | ||
nn.SiLU(), | ||
) | ||
|
||
# TODO: adaptive downsample | ||
self.n_downsample = 1 | ||
|
||
module_list = [] | ||
|
||
for _ in range(self.n_downsample): | ||
module_list.append( | ||
ConvBlock(nch_hidden, 2 * nch_hidden, kernel_size=5, stride=2, padding=2) | ||
) | ||
nch_hidden *= 2 | ||
|
||
module_list.extend( | ||
[ | ||
ConvBlock(nch_hidden, 64, kernel_size=5, padding=2), | ||
nn.Sequential(*[ConvBlock(64, 64, kernel_size=5, padding=2) for _ in range(1)]), | ||
ConvBlock(64, 128, stride=2), | ||
nn.Sequential(*[ConvBlock(128, 128) for _ in range(1)]), | ||
ConvBlock(128, num_features, stride=1), | ||
] | ||
) # 4 | ||
|
||
self.net = nn.ModuleList(module_list) | ||
|
||
def forward(self, x): | ||
x = self.preprocess3d(x).squeeze(2) | ||
for _i, m in enumerate(self.net): | ||
x = m(x) | ||
|
||
return x | ||
|
||
|
||
class WeakLensingCatalogNet(nn.Module): | ||
def __init__(self, in_channels, out_channels): | ||
super().__init__() | ||
|
||
net_layers = [ | ||
C3(in_channels, 256, n=1, shortcut=True), # 0 | ||
ConvBlock(256, 512, stride=2), | ||
C3(512, 256, n=1, shortcut=True), # true shortcut for skip connection | ||
ConvBlock( | ||
in_channels=256, out_channels=256, kernel_size=3, stride=8, padding=1 | ||
), # (1, 256, 128, 128) | ||
ConvBlock( | ||
in_channels=256, out_channels=256, kernel_size=3, stride=4, padding=1 | ||
), # (1, 256, 8, 8) | ||
Detect(256, out_channels), | ||
] | ||
self.net = nn.ModuleList(net_layers) | ||
|
||
def forward(self, x): | ||
for _i, m in enumerate(self.net): | ||
x = m(x) | ||
return x |
Oops, something went wrong.