Skip to content

Commit 4037271

Browse files
authored
Merge pull request #61 from Kevin2/511_pixel
Add functionalities for 511 pixel pictures
2 parents 62b6b68 + e5e4500 commit 4037271

File tree

8 files changed

+79
-26
lines changed

8 files changed

+79
-26
lines changed

Diff for: dl_framework/architectures.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def forward(self, x):
692692

693693

694694
class filter_deep_amp(nn.Module):
695-
def __init__(self):
695+
def __init__(self, img_size):
696696
super().__init__()
697697
self.conv1_amp = nn.Sequential(
698698
*conv_amp(1, 4, (23, 23), 1, 11, 1)
@@ -704,7 +704,7 @@ def __init__(self):
704704
*conv_amp(8, 12, (17, 17), 1, 8, 1)
705705
)
706706
self.conv_con1_amp = nn.Sequential(
707-
LocallyConnected2d(12, 1, 63, 1, stride=1, bias=False),
707+
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
708708
nn.BatchNorm2d(1),
709709
nn.ReLU(),
710710
)
@@ -722,7 +722,7 @@ def __init__(self):
722722
*conv_amp(12, 16, (3, 3), 1, 1, 1)
723723
)
724724
self.conv_con2_amp = nn.Sequential(
725-
LocallyConnected2d(16, 1, 63, 1, stride=1, bias=False),
725+
LocallyConnected2d(16, 1, img_size, 1, stride=1, bias=False),
726726
nn.BatchNorm2d(1),
727727
nn.ReLU(),
728728
)
@@ -737,7 +737,7 @@ def __init__(self):
737737
*conv_amp(8, 12, (3, 3), 1, 2, 2)
738738
)
739739
self.conv_con3_amp = nn.Sequential(
740-
LocallyConnected2d(12, 1, 63, 1, stride=1, bias=False),
740+
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
741741
nn.BatchNorm2d(1),
742742
nn.ReLU(),
743743
)
@@ -776,13 +776,15 @@ def forward(self, x):
776776

777777
amp = self.conv_con3_amp(amp)
778778

779-
amp = amp + inp[:, 0].unsqueeze(1)
780-
x0 = self.symmetry_real(amp).reshape(-1, 1, 63, 63)
779+
inp_amp = inp[:, 0].unsqueeze(1)
780+
x0 = self.symmetry_real(amp).reshape(-1, 1, amp.shape[2], amp.shape[2])
781+
x0[inp_amp != 0] = inp_amp[inp_amp != 0]
782+
781783
return x0
782784

783785

784786
class filter_deep_phase(nn.Module):
785-
def __init__(self):
787+
def __init__(self, img_size):
786788
super().__init__()
787789
self.conv1_phase = nn.Sequential(
788790
*conv_phase(1, 4, (23, 23), 1, 11, 1, add=-2.1415)
@@ -794,7 +796,7 @@ def __init__(self):
794796
*conv_phase(8, 12, (17, 17), 1, 8, 1, add=-2.1415)
795797
)
796798
self.conv_con1_phase = nn.Sequential(
797-
LocallyConnected2d(12, 1, 63, 1, stride=1, bias=False),
799+
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
798800
nn.BatchNorm2d(1),
799801
GeneralELU(-2.1415),
800802
)
@@ -812,7 +814,7 @@ def __init__(self):
812814
*conv_phase(12, 16, (3, 3), 1, 1, 1, add=-2.1415)
813815
)
814816
self.conv_con2_phase = nn.Sequential(
815-
LocallyConnected2d(16, 1, 63, 1, stride=1, bias=False),
817+
LocallyConnected2d(16, 1, img_size, 1, stride=1, bias=False),
816818
nn.BatchNorm2d(1),
817819
GeneralELU(-2.1415),
818820
)
@@ -827,7 +829,7 @@ def __init__(self):
827829
*conv_phase(8, 12, (3, 3), 1, 2, 2, add=-2.1415)
828830
)
829831
self.conv_con3_phase = nn.Sequential(
830-
LocallyConnected2d(12, 1, 63, 1, stride=1, bias=False),
832+
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
831833
nn.BatchNorm2d(1),
832834
GeneralELU(-2.1415),
833835
)
@@ -866,6 +868,8 @@ def forward(self, x):
866868

867869
phase = self.conv_con3_phase(phase)
868870

869-
phase = phase + inp[:, 1].unsqueeze(1)
870-
x1 = self.symmetry_imag(phase).reshape(-1, 1, 63, 63)
871+
inp_phase = inp[:, 1].unsqueeze(1)
872+
873+
x1 = self.symmetry_imag(phase).reshape(-1, 1, phase.shape[2], phase.shape[2])
874+
x1[inp_phase != 0] = inp_phase[inp_phase != 0]
871875
return x1

Diff for: dl_framework/data.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ def __getitem__(self, i):
6464
return x, y
6565

6666
def open_bundle(self, bundle_path, var):
67-
bundle = h5py.File(bundle_path, "r")
67+
# distinguish between compressed (npz) or not compressed (h5)
68+
if re.search(".npz", str(bundle_path)):
69+
bundle = np.load(bundle_path, mmap_mode="r")
70+
else:
71+
bundle = h5py.File(bundle_path, "r")
6872
data = bundle[var]
6973
return data
7074

@@ -77,14 +81,21 @@ def open_image(self, var, i):
7781
bundle = indices // self.num_img
7882
image = indices - bundle * self.num_img
7983
bundle_unique = torch.unique(bundle)
80-
bundle_paths = [
81-
h5py.File(self.bundles[bundle], "r") for bundle in bundle_unique
82-
]
84+
# distinguish between compressed (npz) or not compressed (h5)
85+
if re.search(".npz", str(self.bundles[bundle[0]])):
86+
bundle_paths = [
87+
np.load(self.bundles[bundle], mmap_mode='r') for bundle in bundle_unique
88+
]
89+
else:
90+
bundle_paths = [
91+
h5py.File(self.bundles[bundle], 'r') for bundle in bundle_unique
92+
]
93+
bundle_paths_str = list(map(str, bundle_paths))
8394
data = torch.tensor(
8495
[
8596
bund[var][img]
86-
for bund in bundle_paths
87-
for img in image[bundle == bundle_unique[bundle_paths.index(bund)]]
97+
for bund, bund_str in zip(bundle_paths, bundle_paths_str)
98+
for img in image[bundle == bundle_unique[bundle_paths_str.index(bund_str)]]
8899
]
89100
)
90101
if var == "x" or self.tar_fourier is True:
@@ -193,6 +204,16 @@ def open_fft_pair(path):
193204
return bundle_x, bundle_y
194205

195206

207+
def open_fft_pair_npz(path):
208+
"""
209+
open fft_pairs for files saved in .npz format
210+
"""
211+
f = np.load(path)
212+
bundle_x = np.array(f["x"])
213+
bundle_y = np.array(f["y"])
214+
return bundle_x, bundle_y
215+
216+
196217
def mean_and_std(array):
197218
return array.mean(), array.std()
198219

Diff for: gaussian_sources/Makefile

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ ${data}/fft_samp_train0.h5:
2626
-amp_phase ${amp_phase} \
2727
-noise ${noise} \
2828
-preview ${preview} \
29-
-specific_mask True \
29+
-specific_mask False \
30+
-compressed True\
3031
-lon -80 \
3132
-lat 50 \
3233
-steps 50

Diff for: gaussian_sources/calculate_normalization.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import click
2-
from dl_framework.data import open_fft_pair, get_bundles, mean_and_std
2+
from dl_framework.data import (
3+
open_fft_pair_npz,
4+
open_fft_pair,
5+
get_bundles,
6+
mean_and_std,
7+
)
38
import pandas as pd
49
import numpy as np
510
import re
@@ -21,7 +26,11 @@ def main(data_path, out_path):
2126
stds_imag = np.array([])
2227

2328
for path in tqdm(bundle_paths):
24-
x, _ = open_fft_pair(path)
29+
# distinguish between compressed (.npz) and not compressed (.h5) files
30+
if re.search(".npz", str(path)):
31+
x, _ = open_fft_pair_npz(path)
32+
else:
33+
x, _ = open_fft_pair(path)
2534
x_amp, x_imag = np.double(x[:, 0]), np.double(x[:, 1])
2635
mean_amp, std_amp = mean_and_std(x_amp)
2736
mean_imag, std_imag = mean_and_std(x_imag)
@@ -38,8 +47,8 @@ def main(data_path, out_path):
3847
d = {
3948
"train_mean_c0": [mean_amp],
4049
"train_std_c0": [std_amp],
41-
'train_mean_c1': [mean_imag],
42-
'train_std_c1': [std_imag]
50+
"train_mean_c1": [mean_imag],
51+
"train_std_c1": [std_imag],
4352
}
4453

4554
df = pd.DataFrame(data=d)

Diff for: gaussian_sources/create_fft_pairs.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import click
2+
import os
23
import numpy as np
34
from tqdm import tqdm
45
from dl_framework.data import (
@@ -11,6 +12,7 @@
1112
from simulations.uv_simulations import sample_freqs
1213
from simulations.gaussian_simulations import add_noise
1314
import re
15+
from numpy import savez_compressed
1416

1517

1618
@click.command()
@@ -20,6 +22,7 @@
2022
@click.option("-amp_phase", type=bool, required=True)
2123
@click.option("-fourier", type=bool, required=True)
2224
@click.option("-specific_mask", type=bool)
25+
@click.option("-compressed", type=bool)
2326
@click.option("-lon", type=float, required=False)
2427
@click.option("-lat", type=float, required=False)
2528
@click.option("-steps", type=float, required=False)
@@ -31,6 +34,7 @@ def main(
3134
antenna_config_path,
3235
amp_phase,
3336
fourier,
37+
compressed=False,
3438
specific_mask=False,
3539
lon=None,
3640
lat=None,
@@ -69,6 +73,11 @@ def main(
6973
if amp_phase:
7074
amp, phase = split_amp_phase(bundle_fft)
7175
amp = (np.log10(amp + 1e-10) / 10) + 1
76+
77+
# Test new masking for 511 Pixel pictures
78+
if amp.shape[1] == 511:
79+
mask = amp > 0.1
80+
phase[~mask] = 0
7281
bundle_fft = np.stack((amp, phase), axis=1)
7382
else:
7483
real, imag = split_real_imag(bundle_fft)
@@ -100,7 +109,11 @@ def main(
100109
)
101110
out = out_path + path.name.split("_")[-1]
102111
if fourier:
103-
save_fft_pair(out, bundle_samp, bundle_fft)
112+
if compressed:
113+
savez_compressed(out, x=bundle_samp, y=bundle_fft)
114+
os.remove(path)
115+
else:
116+
save_fft_pair(out, bundle_samp, bundle_fft)
104117
else:
105118
save_fft_pair(out, bundle_samp, images)
106119

Diff for: gaussian_sources/find_lr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def main(
8888
img_size = train_ds[0][0][0].shape[1]
8989
# Define model
9090
arch_name = arch
91-
if arch == "filter_deep":
91+
if arch == "filter_deep" or arch == "filter_deep_amp" or arch == "filter_deep_phase":
9292
arch = getattr(architecture, arch)(img_size)
9393
else:
9494
arch = getattr(architecture, arch)()

Diff for: gaussian_sources/inspection.py

+5
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,11 @@ def plot_difference(i, img_pred, img_truth, sensitivity, out_path):
317317
num_three = 60
318318
num_two = 50
319319
num_four = 40
320+
elif img_size == 511:
321+
# work in progress, these are dummy values for compilating
322+
num_three = 60
323+
num_two = 50
324+
num_four = 40
320325
num = [num_three, num_two, num_four]
321326
dr_truth, mode = compute_dr(i, img_truth, sensitivity, num)
322327
dr_pred = compute_dr_pred(img_pred, mode, num)

Diff for: gaussian_sources/train_cnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def main(
8383

8484
img_size = train_ds[0][0][0].shape[1]
8585
# Define model
86-
if arch == "filter_deep":
86+
if arch == "filter_deep" or arch == "filter_deep_amp" or arch == "filter_deep_phase":
8787
arch = getattr(architecture, arch)(img_size)
8888
else:
8989
arch = getattr(architecture, arch)()

0 commit comments

Comments
 (0)