Skip to content

Commit 84ca732

Browse files
authored
high-resolution images & a 2d discrete variational factor (#1068)
* null sample renders bug * fix annoying data_source warning * aug5_discretizedbox_quarterpixels * not much, some m2 notebook stuff * sdss case study * revert convnets to master * fixed some merge discrepancies * bring back former to_tile_catalog * self.double_sample_prob and nll hack * sdss demo needs retrain with one band fluxes * tiny changes to m2 case study * revert problematic changes * fix bin_cutoffs * logsumexp for double detections * back to 32-true * lots of 1x1 kernels * updated sample for logsumexp double detect * wider context net * removed incorrect masking * gate new_est_cat2 during sampling * simplify color context * count_net * countnet refactor * richer color history * no groupnorm in heads * minimalist_history * deeper * 1x1 for real * centered locs for colors * restore flux history to local context * add n_sources to color history * only nsources color * spatial context for color * spatial countnet too * locs only local context * embedding * restore fluxes to local context * groupnorm4all less color spatial * extra spatial for countnet and color * shallower * earlier first skip connection * null normalizer * no groupnorm for localnet or detectionnet * restore normalizers * flux in color context * recovered from merge? * use multiprocessing for generate * add mask patterns * pylint * flake8 * remove simulator * forgot to remove line * wrong path * simplify prior * restored double downsample * remove fixtures; fix test paths * revert log10 division change * tiles_to_crop * smaller images * new base_config encoder * notebooks running again
1 parent bd96328 commit 84ca732

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+881
-1282
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ disable=too-many-ancestors,
5151
unused-argument,
5252

5353
# flake8 recommends against f-strings
54-
consider-using-f-string,
54+
logging-fstring-interpolation,
5555

5656
# flake8 already checks for lambda expressions, which are OK at times
5757
unnecessary-lambda-assignment,

bliss/cached_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __call__(self, datum_in):
9393

9494
class ChunkingSampler(Sampler):
9595
def __init__(self, dataset: Dataset) -> None:
96+
super().__init__()
9697
# please don't pass dataset to the following __init__()
9798
# according to https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler
9899
# the parameter `data_source` has been deprecated

bliss/catalog.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,38 @@ def symmetric_crop(self, tiles_to_crop):
8282
[tiles_to_crop, self.n_tiles_w - tiles_to_crop],
8383
)
8484

85+
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
86+
assert box_origin[0] + box_len < self.height, "invalid box"
87+
assert box_origin[1] + box_len < self.width, "invalid box"
88+
89+
box_origin_tensor = box_origin.view(1, 1, 2).to(device=self.device)
90+
box_end_tensor = (box_origin + box_len).view(1, 1, 2).to(device=self.device)
91+
92+
plocs_mask = torch.all(
93+
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
94+
)
95+
96+
plocs_mask_indexes = plocs_mask.nonzero()
97+
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
98+
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
99+
_, index_order = plocs_full_mask_indexes[:, 0].sort(stable=True)
100+
plocs_full_mask_sorted_indexes = plocs_full_mask_indexes[index_order.tolist(), :]
101+
102+
d = {}
103+
new_max_sources = plocs_mask.sum(dim=1).max()
104+
for k, v in self.items():
105+
if k == "n_sources":
106+
d[k] = plocs_mask.sum(dim=1)
107+
else:
108+
d[k] = v[
109+
plocs_full_mask_sorted_indexes[:, 0].tolist(),
110+
plocs_full_mask_sorted_indexes[:, 1].tolist(),
111+
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]
112+
113+
d["plocs"] -= box_origin_tensor
114+
115+
return FullCatalog(box_len, box_len, d)
116+
85117

86118
class TileCatalog(BaseTileCatalog):
87119
galaxy_params = [
@@ -335,9 +367,8 @@ def union(self, other, disjoint=False):
335367
ns11 = rearrange(self["n_sources"], "b ht wt -> b ht wt 1 1")
336368
for k, v in self.items():
337369
if k == "n_sources":
370+
assert not disjoint or ((v == 0) | (other[k] == 0)).all()
338371
d[k] = v + other[k]
339-
if disjoint:
340-
assert d[k].max() <= 1
341372
else:
342373
if disjoint:
343374
d1 = v
@@ -734,9 +765,7 @@ def to_tile_catalog(
734765

735766
return TileCatalog(tile_params)
736767

737-
# pylint: enable=R0912,R0915
738-
739-
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
768+
def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float, exclude_box=False):
740769
assert box_origin[0] + box_len <= self.height, "invalid box"
741770
assert box_origin[1] + box_len <= self.width, "invalid box"
742771

@@ -747,6 +776,9 @@ def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
747776
(self["plocs"] < box_end_tensor) & (self["plocs"] > box_origin_tensor), dim=2
748777
)
749778

779+
if exclude_box:
780+
plocs_mask = ~plocs_mask
781+
750782
plocs_mask_indexes = plocs_mask.nonzero()
751783
plocs_inverse_mask_indexes = (~plocs_mask).nonzero()
752784
plocs_full_mask_indexes = torch.cat((plocs_mask_indexes, plocs_inverse_mask_indexes), dim=0)
@@ -764,6 +796,8 @@ def filter_by_ploc_box(self, box_origin: torch.Tensor, box_len: float):
764796
plocs_full_mask_sorted_indexes[:, 1].tolist(),
765797
].view(-1, self.max_sources, v.shape[-1])[:, :new_max_sources, :]
766798

767-
d["plocs"] -= box_origin_tensor
799+
if exclude_box:
800+
return FullCatalog(self.height, self.width, d)
768801

802+
d["plocs"] -= box_origin_tensor
769803
return FullCatalog(box_len, box_len, d)

bliss/conf/base_config.yaml

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,30 @@ paths:
1717
cached_data: /data/scratch/regier/sdss_like
1818
output: ${oc.env:HOME}/bliss_output
1919

20-
# this prior is sdss-like; the flux parameters were fit using SDSS catalogs
20+
# this prior is sdss-like; the parameters were fit using SDSS catalogs
2121
prior:
2222
_target_: bliss.simulator.prior.CatalogPrior
2323
survey_bands: [u, g, r, i, z] # SDSS available band filters
2424
reference_band: 2 # SDSS r-band
25-
star_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/star_gmm_nmgy.pkl
26-
gal_color_model_path: ${simulator.decoder.survey.dir_path}/color_models/gal_gmm_nmgy.pkl
27-
n_tiles_h: 20
28-
n_tiles_w: 20
29-
batch_size: 64
25+
star_color_model_path: ${paths.sdss}/color_models/star_gmm_nmgy.pkl
26+
gal_color_model_path: ${paths.sdss}/color_models/gal_gmm_nmgy.pkl
27+
n_tiles_h: 68 # cropping 2 tiles from each side
28+
n_tiles_w: 68 # cropping 2 tiles from each side
29+
batch_size: 8
3030
max_sources: 1
31-
mean_sources: 0.01 # 0.0025 is more realistic for SDSS but training takes more iterations
31+
mean_sources: 0.0025
3232
min_sources: 0
3333
prob_galaxy: 0.5144
34-
star_flux_exponent: 0.4689157382430609
35-
star_flux_truncation: 613313.768995269
36-
star_flux_loc: -0.5534648001193676
37-
star_flux_scale: 1.1846035501201129
38-
galaxy_flux_exponent: 1.5609458661807678
39-
galaxy_flux_truncation: 28790.449063519092
40-
galaxy_flux_loc: -3.29383532288203
41-
galaxy_flux_scale: 3.924799999613338
34+
star_flux:
35+
exponent: 0.4689157382430609
36+
truncation: 613313.768995269
37+
loc: -0.5534648001193676
38+
scale: 1.1846035501201129
39+
galaxy_flux:
40+
exponent: 1.5609458661807678
41+
truncation: 28790.449063519092
42+
loc: -3.29383532288203
43+
scale: 3.924799999613338
4244
galaxy_a_concentration: 0.39330758068481686
4345
galaxy_a_loc: 0.8371888967872619
4446
galaxy_a_scale: 4.432725319432478
@@ -51,20 +53,11 @@ decoder:
5153
with_dither: true
5254
with_noise: true
5355

54-
simulator:
55-
_target_: bliss.simulator.simulated_dataset.SimulatedDataset
56-
prior: ${prior}
57-
decoder: ${decoder}
58-
n_batches: 128
59-
num_workers: 32
60-
valid_n_batches: 10 # 256
61-
fix_validation_set: true
62-
6356
cached_simulator:
6457
_target_: bliss.cached_dataset.CachedSimulatedDataModule
65-
batch_size: 64
58+
batch_size: 16
6659
splits: 0:80/80:90/90:100 # train/val/test splits as percent ranges
67-
num_workers: 8
60+
num_workers: 4
6861
cached_data_path: ${paths.cached_data}
6962
train_transforms:
7063
- _target_: bliss.data_augmentation.RotateFlipTransform
@@ -140,23 +133,42 @@ variational_factors:
140133
nll_gating:
141134
_target_: bliss.encoder.variational_dist.GalaxyGating
142135

136+
# these are in nanomaggies
137+
sdss_mag_zero_point: 1e9
138+
sdss_flux_cutoffs:
139+
- 1.4928
140+
- 1.9055
141+
- 2.7542
142+
- 3.9811
143+
- 5.7544
144+
- 8.3176
145+
- 12.0227
146+
- 17.3780
147+
- 25.1189
148+
143149
metrics:
144150
detection_performance:
145151
_target_: bliss.encoder.metrics.DetectionPerformance
146-
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
147-
mag_zero_point: 3631e9 # for DC2
152+
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
153+
mag_zero_point: ${sdss_mag_zero_point}
148154
report_bin_unit: mag
155+
exclude_last_bin: true
156+
ref_band: 2
149157
source_type_accuracy:
150158
_target_: bliss.encoder.metrics.SourceTypeAccuracy
151-
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
152-
mag_zero_point: 3631e9 # for DC2
159+
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
160+
mag_zero_point: ${sdss_mag_zero_point}
153161
report_bin_unit: mag
162+
exclude_last_bin: true
163+
ref_band: 2
154164
flux_error:
155165
_target_: bliss.encoder.metrics.FluxError
156166
survey_bands: ${encoder.survey_bands}
157-
base_flux_bin_cutoffs: [200, 400, 600, 800, 1000]
158-
mag_zero_point: 3631e9 # for DC2
167+
base_flux_bin_cutoffs: ${sdss_flux_cutoffs}
168+
mag_zero_point: ${sdss_mag_zero_point}
159169
report_bin_unit: mag
170+
exclude_last_bin: true
171+
ref_band: 2
160172

161173
image_normalizers:
162174
psf:
@@ -173,7 +185,7 @@ encoder:
173185
_target_: bliss.encoder.encoder.Encoder
174186
survey_bands: [u, g, r, i, z]
175187
reference_band: 2 # SDSS r-band
176-
tile_slen: ${simulator.decoder.tile_slen}
188+
tile_slen: ${decoder.tile_slen}
177189
optimizer_params:
178190
lr: 1e-3
179191
scheduler_params:
@@ -201,7 +213,7 @@ encoder:
201213
frequency: 1
202214
restrict_batch: 0
203215
tiles_to_crop: 0
204-
tile_slen: ${simulator.decoder.tile_slen}
216+
tile_slen: ${decoder.tile_slen}
205217
use_double_detect: false
206218
use_checkerboard: false
207219
n_sampler_colors: 4
@@ -278,11 +290,13 @@ surveys:
278290
mode: train
279291

280292
generate:
281-
n_image_files: 64
282-
n_batches_per_file: 16
283-
simulator: ${simulator}
293+
prior: ${prior}
294+
decoder: ${decoder}
295+
tiles_to_crop: 2
296+
n_image_files: 512
297+
n_batches_per_file: 32 # multiply by prior.batch_size to get total number of images
298+
n_processes: 16 # using more isn't necessarily faster
284299
cached_data_path: ${paths.cached_data}
285-
file_prefix: dataset
286300
store_full_catalog: false
287301

288302
train:

bliss/encoder/convnet_layers.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66

77

88
class ConvBlock(nn.Module):
9-
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
9+
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, gn=True):
1010
super().__init__()
11+
assert kernel_size % 2 == 1, "kernel size must be odd"
12+
padding = (kernel_size - 1) // 2
1113
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
1214
# seems to work about as well as BatchNorm2d
13-
self.norm = nn.GroupNorm(out_channels // 8, out_channels)
15+
n_groups = out_channels // 8
16+
use_gn = gn and n_groups >= 16
17+
self.norm = nn.GroupNorm(n_groups, out_channels) if use_gn else nn.Identity()
1418
self.activation = nn.SiLU(inplace=True)
1519

1620
def forward(self, x):
@@ -27,11 +31,12 @@ def forward(self, x):
2731

2832

2933
class Bottleneck(nn.Module):
30-
def __init__(self, c1, c2, shortcut=True, e=0.5):
34+
def __init__(self, c1, c2, shortcut=True, e=0.5, gn=True, spatial=True):
3135
super().__init__()
3236
ch = int(c2 * e)
33-
self.cv1 = ConvBlock(c1, ch, kernel_size=1, padding=0)
34-
self.cv2 = ConvBlock(ch, c2, kernel_size=3, padding=1, stride=1)
37+
self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
38+
ks = 3 if spatial else 1
39+
self.cv2 = ConvBlock(ch, c2, kernel_size=ks, stride=1, gn=gn)
3540
self.add = shortcut and c1 == c2
3641

3742
def forward(self, x):
@@ -40,13 +45,15 @@ def forward(self, x):
4045

4146

4247
class C3(nn.Module):
43-
def __init__(self, c1, c2, n=1, shortcut=True, e=0.5):
48+
def __init__(self, c1, c2, n=1, shortcut=True, e=0.5, gn=True, spatial=True):
4449
super().__init__()
4550
ch = int(c2 * e)
46-
self.cv1 = ConvBlock(c1, ch, kernel_size=1, padding=0)
47-
self.cv2 = ConvBlock(c1, ch, kernel_size=1, padding=0)
48-
self.cv3 = ConvBlock(2 * ch, c2, kernel_size=1, padding=0)
49-
self.m = nn.Sequential(*(Bottleneck(ch, ch, shortcut, e=1.0) for _ in range(n)))
51+
self.cv1 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
52+
self.cv2 = ConvBlock(c1, ch, kernel_size=1, gn=gn)
53+
self.cv3 = ConvBlock(2 * ch, c2, kernel_size=1, gn=gn)
54+
self.m = nn.Sequential(
55+
*(Bottleneck(ch, ch, shortcut, e=1.0, spatial=spatial) for _ in range(n)),
56+
)
5057

5158
def forward(self, x):
5259
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))

0 commit comments

Comments
 (0)