Skip to content

Commit

Permalink
high-resolution images & a 2d discrete variational factor (#1068)
Browse files Browse the repository at this point in the history
* null sample renders bug

* fix annoying data_source warning

* aug5_discretizedbox_quarterpixels

* not much, some m2 notebook stuff

* sdss case study

* revert convnets to master

* fixed some merge discrepencies

* bring back former to_tile_catalog

* self.double_sample_prob and nll hack

* sdss demo needs retrain with one band fluxes

* tiny changes to m2 case study

* revert problematic changes

* fix bin_cutoffs

* logsumexp for double detections

* back to 32-true

* lots of 1x1 kernels

* updated sample for logsumexp double detect

* wider context net

* removed incorrect masking

* gate new_est_cat2 during sampling

* simplify color context

* count_net

* countnet refactor

* richer color history

* no groupnorm in heads

* minimalist_history

* deeper

* 1x1 for real

* centered locs for colors

* restore flux history to local context

* add n_sources to color history

* only nsources color

* spatial context for color

* spatial countnet too

* locs only local context

* embedding

* restore fluxes to local context

* groupnorm4all less color spatial

* extra spatial for countnet and color

* shallower

* earlier first skip connection

* null normalizer

* no groupnorm for localnet or detectionnet

* restore normalizers

* flux in color context

* recovered from merge?

* use multiprocessing for generate

* add mask patterns

* pylint

* flake8

* remove simulator

* forgot to remove line

* wrong path

* simplify prior

* restored double downsample

* remove fixtures; fix test paths

* revert log10 division change

* tiles_to_crop

* smaller images

* new base_config encoder

* notebooks running again
  • Loading branch information
jeff-regier authored Sep 3, 2024
1 parent bd96328 commit 84ca732
Show file tree
Hide file tree
Showing 41 changed files with 881 additions and 1,282 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ disable=too-many-ancestors,
unused-argument,

# flake8 recommends against f-strings
consider-using-f-string,
logging-fstring-interpolation,

# flake8 already checks for lambda expressions, which are OK at times
unnecessary-lambda-assignment,
Expand Down
1 change: 1 addition & 0 deletions bliss/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __call__(self, datum_in):

class ChunkingSampler(Sampler):
def __init__(self, dataset: Dataset) -> None:
super().__init__()
# please don't pass dataset to the following __init__()
# according to https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
# the parameter `data_source` has been deprecated
Expand Down
46 changes: 40 additions & 6 deletions bliss/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@ def symmetric_crop(self, tiles_to_crop):
[tiles_to_crop, self.n_tiles_w - tiles_to_crop],
)

def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
assert box_origin[0] + box_len < self.height, "invalid box"
assert box_origin[1] + box_len < self.width, "invalid box"

box_origin_tensor = box_origin.view(1, 1, 2).to(device=self.device)
box_end_tensor = (box_origin + box_len).view(1, 1, 2).to(device=self.device)

plocs_mask = torch.all(
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
)

plocs_mask_indexes = plocs_mask.nonzero()
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
_, index_order = plocs_full_mask_indexes[:, 0].sort(stable=True)
plocs_full_mask_sorted_indexes = plocs_full_mask_indexes[index_order.tolist(), :]

d = {}
new_max_sources = plocs_mask.sum(dim=1).max()
for k, v in self.items():
if k == "n_sources":
d[k] = plocs_mask.sum(dim=1)
else:
d[k] = v[
plocs_full_mask_sorted_indexes[:, 0].tolist(),
plocs_full_mask_sorted_indexes[:, 1].tolist(),
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]

d["plocs"] -= box_origin_tensor

return FullCatalog(box_len, box_len, d)


class TileCatalog(BaseTileCatalog):
galaxy_params = [
Expand Down Expand Up @@ -335,9 +367,8 @@ def union(self, other, disjoint=False):
ns11 = rearrange(self["n_sources"], "b ht wt -> b ht wt 1 1")
for k, v in self.items():
if k == "n_sources":
assert not disjoint or ((v == 0) | (other[k] == 0)).all()
d[k] = v + other[k]
if disjoint:
assert d[k].max() <= 1
else:
if disjoint:
d1 = v
Expand Down Expand Up @@ -734,9 +765,7 @@ def to_tile_catalog(

return TileCatalog(tile_params)

# pylint: enable=R0912,R0915

def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float, exclude_box=False):
assert box_origin[0] + box_len <= self.height, "invalid box"
assert box_origin[1] + box_len <= self.width, "invalid box"

Expand All @@ -747,6 +776,9 @@ def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
)

if exclude_box:
plocs_mask = ~plocs_mask

plocs_mask_indexes = plocs_mask.nonzero()
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
Expand All @@ -764,6 +796,8 @@ def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
plocs_full_mask_sorted_indexes[:, 1].tolist(),
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]

d["plocs"] -= box_origin_tensor
if exclude_box:
return FullCatalog(self.height, self.width, d)

d["plocs"] -= box_origin_tensor
return FullCatalog(box_len, box_len, d)
90 changes: 52 additions & 38 deletions bliss/conf/base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,30 @@ paths:
cached_data: /data/scratch/regier/sdss_like
output: ${oc.env:HOME}/bliss_output

# this prior is sdss-like; the flux parameters were fit using SDSS catalogs
# this prior is sdss-like; the parameters were fit using SDSS catalogs
prior:
_target_: bliss.simulator.prior.CatalogPrior
survey_bands: [u, g, r, i, z] # SDSS available band filters
reference_band: 2 # SDSS r-band
star_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 20
n_tiles_w: 20
batch_size: 64
star_color_model_path: ${paths.sdss}/color_models/star_gmm_nmgy.pkl
gal_color_model_path: ${paths.sdss}/color_models/gal_gmm_nmgy.pkl
n_tiles_h: 68 # cropping 2 tiles from each side
n_tiles_w: 68 # cropping 2 tiles from each side
batch_size: 8
max_sources: 1
mean_sources: 0.01 # 0.0025 is more realistic for SDSS but training takes more iterations
mean_sources: 0.0025
min_sources: 0
prob_galaxy: 0.5144
star_flux_exponent: 0.4689157382430609
star_flux_truncation: 613313.768995269
star_flux_loc: -0.5534648001193676
star_flux_scale: 1.1846035501201129
galaxy_flux_exponent: 1.5609458661807678
galaxy_flux_truncation: 28790.449063519092
galaxy_flux_loc: -3.29383532288203
galaxy_flux_scale: 3.924799999613338
star_flux:
exponent: 0.4689157382430609
truncation: 613313.768995269
loc: -0.5534648001193676
scale: 1.1846035501201129
galaxy_flux:
exponent: 1.5609458661807678
truncation: 28790.449063519092
loc: -3.29383532288203
scale: 3.924799999613338
galaxy_a_concentration: 0.39330758068481686
galaxy_a_loc: 0.8371888967872619
galaxy_a_scale: 4.432725319432478
Expand All @@ -51,20 +53,11 @@ decoder:
with_dither: true
with_noise: true

simulator:
_target_: bliss.simulator.simulated_dataset.SimulatedDataset
prior: ${prior}
decoder: ${decoder}
n_batches: 128
num_workers: 32
valid_n_batches: 10 # 256
fix_validation_set: true

cached_simulator:
_target_: bliss.cached_dataset.CachedSimulatedDataModule
batch_size: 64
batch_size: 16
splits: 0:80/80:90/90:100 # train/val/test splits as percent ranges
num_workers: 8
num_workers: 4
cached_data_path: ${paths.cached_data}
train_transforms:
- _target_: bliss.data_augmentation.RotateFlipTransform
Expand Down Expand Up @@ -140,23 +133,42 @@ variational_factors:
nll_gating:
_target_: bliss.encoder.variational_dist.GalaxyGating

# these are in nanomaggies
sdss_mag_zero_point: 1e9
sdss_flux_cutoffs:
- 1.4928
- 1.9055
- 2.7542
- 3.9811
- 5.7544
- 8.3176
- 12.0227
- 17.3780
- 25.1189

metrics:
detection_performance:
_target_: bliss.encoder.metrics.DetectionPerformance
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2
source_type_accuracy:
_target_: bliss.encoder.metrics.SourceTypeAccuracy
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2
flux_error:
_target_: bliss.encoder.metrics.FluxError
survey_bands: ${encoder.survey_bands}
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
mag_zero_point: 3631e9 # for DC2
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
mag_zero_point: ${sdss_mag_zero_point}
report_bin_unit: mag
exclude_last_bin: true
ref_band: 2

image_normalizers:
psf:
Expand All @@ -173,7 +185,7 @@ encoder:
_target_: bliss.encoder.encoder.Encoder
survey_bands: [u, g, r, i, z]
reference_band: 2 # SDSS r-band
tile_slen: ${simulator.decoder.tile_slen}
tile_slen: ${decoder.tile_slen}
optimizer_params:
lr: 1e-3
scheduler_params:
Expand Down Expand Up @@ -201,7 +213,7 @@ encoder:
frequency: 1
restrict_batch: 0
tiles_to_crop: 0
tile_slen: ${simulator.decoder.tile_slen}
tile_slen: ${decoder.tile_slen}
use_double_detect: false
use_checkerboard: false
n_sampler_colors: 4
Expand Down Expand Up @@ -278,11 +290,13 @@ surveys:
mode: train

generate:
n_image_files: 64
n_batches_per_file: 16
simulator: ${simulator}
prior: ${prior}
decoder: ${decoder}
tiles_to_crop: 2
n_image_files: 512
n_batches_per_file: 32 # multiply by prior.batch_size to get total number of images
n_processes: 16 # using more isn't necessarily faster
cached_data_path: ${paths.cached_data}
file_prefix: dataset
store_full_catalog: false

train:
Expand Down
27 changes: 17 additions & 10 deletions bliss/encoder/convnet_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@


class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, gn=True):
super().__init__()
assert kernel_size % 2 == 1, "kernel size must be odd"
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
# seems to work about as well as BatchNorm2d
self.norm = nn.GroupNorm(out_channels // 8, out_channels)
n_groups = out_channels // 8
use_gn = gn and n_groups >= 16
self.norm = nn.GroupNorm(n_groups, out_channels) if use_gn else nn.Identity()
self.activation = nn.SiLU(inplace=True)

def forward(self, x):
Expand All @@ -27,11 +31,12 @@ def forward(self, x):


class Bottleneck(nn.Module):
def __init__(self, c1, c2, shortcut=True, e=0.5):
def __init__(self, c1, c2, shortcut=True, e=0.5, gn=True, spatial=True):
super().__init__()
ch = int(c2 * e)
self.cv1 = ConvBlock(c1, ch, kernel_size=1, padding=0)
self.cv2 = ConvBlock(ch, c2, kernel_size=3, padding=1, stride=1)
self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
ks = 3 if spatial else 1
self.cv2 = ConvBlock(ch, c2, kernel_size=ks, stride=1, gn=gn)
self.add = shortcut and c1 == c2

def forward(self, x):
Expand All @@ -40,13 +45,15 @@ def forward(self, x):


class C3(nn.Module):
def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):
def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, gn=True, spatial=True):
super().__init__()
ch = int(c2 * e)
self.cv1 = ConvBlock(c1, ch, kernel_size=1, padding=0)
self.cv2 = ConvBlock(c1, ch, kernel_size=1, padding=0)
self.cv3 = ConvBlock(2 * ch, c2, kernel_size=1, padding=0)
self.m = nn.Sequential(*(Bottleneck(ch, ch, shortcut, e=1.0) for _ in range(n)))
self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
self.cv2 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
self.cv3 = ConvBlock(2 * ch, c2, kernel_size=1, gn=gn)
self.m = nn.Sequential(
*(Bottleneck(ch, ch, shortcut, e=1.0, spatial=spatial) for _ in range(n)),
)

def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
Loading

0 comments on commit 84ca732

Please sign in to comment.