Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3caddc7

Browse files
authoredOct 22, 2020
Merge pull request #62 from Kevin2/same_img_size
Use the same image size in filter_deep
2 parents 4037271 + 5756293 commit 3caddc7

30 files changed

+1121
-285
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ __pycache__/
1919
*.jpg
2020
*.pdf
2121
*.gif
22+
*.png
2223

2324
# make
2425
*_done

‎.travis.yml

+6
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@ language: python
22
python:
33
- "3.7"
44

5+
# directories:
6+
# - "/tmp/texlive"
7+
# - "$HOME/.texlive"
8+
59
before_install:
610
- sudo apt-get -y install libgeos-dev
711
- sudo apt-get -y install libproj-dev
12+
# - travis_wait 45 bash ./utilities/travis_setup.sh
13+
# - export PATH="/tmp/texlive/bin/x86_64-linux:$PATH"
814

915
install:
1016
- pip install .

‎dl_framework/architectures.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -557,10 +557,10 @@ def __init__(self, img_size):
557557
)
558558

559559
self.conv4_amp = nn.Sequential(
560-
*conv_amp(1, 4, (5, 5), 1, 3, 2)
560+
*conv_amp(1, 4, (5, 5), 1, 4, 2)
561561
)
562562
self.conv4_phase = nn.Sequential(
563-
*conv_phase(1, 4, (5, 5), 1, 3, 2, add=1-pi)
563+
*conv_phase(1, 4, (5, 5), 1, 4, 2, add=1-pi)
564564
)
565565
self.conv5_amp = nn.Sequential(
566566
*conv_amp(4, 8, (5, 5), 1, 2, 1)
@@ -569,10 +569,10 @@ def __init__(self, img_size):
569569
*conv_phase(4, 8, (5, 5), 1, 2, 1, add=1-pi)
570570
)
571571
self.conv6_amp = nn.Sequential(
572-
*conv_amp(8, 12, (3, 3), 1, 3, 2)
572+
*conv_amp(8, 12, (3, 3), 1, 2, 2)
573573
)
574574
self.conv6_phase = nn.Sequential(
575-
*conv_phase(8, 12, (3, 3), 1, 3, 2, add=1-pi)
575+
*conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
576576
)
577577
self.conv7_amp = nn.Sequential(
578578
*conv_amp(12, 16, (3, 3), 1, 1, 1)
@@ -710,13 +710,13 @@ def __init__(self, img_size):
710710
)
711711

712712
self.conv4_amp = nn.Sequential(
713-
*conv_amp(1, 4, (5, 5), 1, 3, 2)
713+
*conv_amp(1, 4, (5, 5), 1, 4, 2)
714714
)
715715
self.conv5_amp = nn.Sequential(
716716
*conv_amp(4, 8, (5, 5), 1, 2, 1)
717717
)
718718
self.conv6_amp = nn.Sequential(
719-
*conv_amp(8, 12, (3, 3), 1, 3, 2)
719+
*conv_amp(8, 12, (3, 3), 1, 2, 2)
720720
)
721721
self.conv7_amp = nn.Sequential(
722722
*conv_amp(12, 16, (3, 3), 1, 1, 1)
@@ -787,51 +787,51 @@ class filter_deep_phase(nn.Module):
787787
def __init__(self, img_size):
788788
super().__init__()
789789
self.conv1_phase = nn.Sequential(
790-
*conv_phase(1, 4, (23, 23), 1, 11, 1, add=-2.1415)
790+
*conv_phase(1, 4, (23, 23), 1, 11, 1, add=1-pi)
791791
)
792792
self.conv2_phase = nn.Sequential(
793-
*conv_phase(4, 8, (21, 21), 1, 10, 1, add=-2.1415)
793+
*conv_phase(4, 8, (21, 21), 1, 10, 1, add=1-pi)
794794
)
795795
self.conv3_phase = nn.Sequential(
796-
*conv_phase(8, 12, (17, 17), 1, 8, 1, add=-2.1415)
796+
*conv_phase(8, 12, (17, 17), 1, 8, 1, add=1-pi)
797797
)
798798
self.conv_con1_phase = nn.Sequential(
799799
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
800800
nn.BatchNorm2d(1),
801-
GeneralELU(-2.1415),
801+
GeneralELU(1-pi),
802802
)
803803

804804
self.conv4_phase = nn.Sequential(
805-
*conv_phase(1, 4, (5, 5), 1, 3, 2, add=-2.1415)
805+
*conv_phase(1, 4, (5, 5), 1, 4, 2, add=1-pi)
806806
)
807807
self.conv5_phase = nn.Sequential(
808-
*conv_phase(4, 8, (5, 5), 1, 2, 1, add=-2.1415)
808+
*conv_phase(4, 8, (5, 5), 1, 2, 1, add=1-pi)
809809
)
810810
self.conv6_phase = nn.Sequential(
811-
*conv_phase(8, 12, (3, 3), 1, 3, 2, add=-2.1415)
811+
*conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
812812
)
813813
self.conv7_phase = nn.Sequential(
814-
*conv_phase(12, 16, (3, 3), 1, 1, 1, add=-2.1415)
814+
*conv_phase(12, 16, (3, 3), 1, 1, 1, add=1-pi)
815815
)
816816
self.conv_con2_phase = nn.Sequential(
817817
LocallyConnected2d(16, 1, img_size, 1, stride=1, bias=False),
818818
nn.BatchNorm2d(1),
819-
GeneralELU(-2.1415),
819+
GeneralELU(1-pi),
820820
)
821821

822822
self.conv8_phase = nn.Sequential(
823-
*conv_phase(1, 4, (3, 3), 1, 1, 1, add=-2.1415)
823+
*conv_phase(1, 4, (3, 3), 1, 1, 1, add=1-pi)
824824
)
825825
self.conv9_phase = nn.Sequential(
826-
*conv_phase(4, 8, (3, 3), 1, 1, 1, add=-2.1415)
826+
*conv_phase(4, 8, (3, 3), 1, 1, 1, add=1-pi)
827827
)
828828
self.conv10_phase = nn.Sequential(
829-
*conv_phase(8, 12, (3, 3), 1, 2, 2, add=-2.1415)
829+
*conv_phase(8, 12, (3, 3), 1, 2, 2, add=1-pi)
830830
)
831831
self.conv_con3_phase = nn.Sequential(
832832
LocallyConnected2d(12, 1, img_size, 1, stride=1, bias=False),
833833
nn.BatchNorm2d(1),
834-
GeneralELU(-2.1415),
834+
GeneralELU(1-pi),
835835
)
836836
self.symmetry_imag = Lambda(partial(symmetry, mode='imag'))
837837

‎dl_framework/callbacks.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,22 @@ def plot_lr(self):
135135
plt.tight_layout()
136136

137137
def plot_loss(self, log=True):
138-
plt.plot(self.train_losses, label="train loss")
139-
plt.plot(self.valid_losses, label="valid loss")
138+
import matplotlib as mpl
139+
140+
# make nice Latex friendly plots
141+
# mpl.use("pgf")
142+
# mpl.rcParams.update(
143+
# {
144+
# "font.size": 12,
145+
# "font.family": "sans-serif",
146+
# "text.usetex": True,
147+
# "pgf.rcfonts": False,
148+
# "pgf.texsystem": "lualatex",
149+
# }
150+
# )
151+
152+
plt.plot(self.train_losses, label="training loss")
153+
plt.plot(self.valid_losses, label="validation loss")
140154
if log:
141155
plt.yscale("log")
142156
plt.xlabel(r"Number of Epochs")

‎dl_framework/data.py

+2
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def open_image(self, var, i):
108108

109109
data_channel = torch.cat([data_amp, data_phase], dim=1)
110110
else:
111+
if data.shape[1] == 2:
112+
raise ValueError("Two channeled data is used despite Fourier being False. Set Fourier to True!")
111113
if len(i) == 1:
112114
data_channel = data.reshape(data.shape[-1] ** 2)
113115
else:

‎dl_framework/inspection.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77
import dl_framework.architectures as architecture
88
from dl_framework.model import load_pre_model
99

10+
# make nice Latex friendly plots
11+
# mpl.use("pgf")
12+
# mpl.rcParams.update(
13+
# {
14+
# "font.size": 12,
15+
# "font.family": "sans-serif",
16+
# "text.usetex": True,
17+
# "pgf.rcfonts": False,
18+
# "pgf.texsystem": "lualatex",
19+
# }
20+
# )
21+
1022

1123
def load_pretrained_model(arch_name, model_path):
1224
"""
@@ -121,7 +133,7 @@ def plot_loss(learn, model_path):
121133
save_path = model_path.split(".model")[0]
122134
print("\nPlotting Loss for: {}\n".format(name_model))
123135
learn.recorder.plot_loss()
124-
plt.title(r"{}".format(name_model))
136+
plt.title(r"{}".format(name_model.replace("_", " ")))
125137
plt.savefig("{}_loss.pdf".format(save_path), bbox_inches="tight", pad_inches=0.01)
126138
plt.clf()
127139
mpl.rcParams.update(mpl.rcParamsDefault)

‎dl_framework/learner.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77
from tqdm import tqdm
88
import sys
99
from functools import partial
10-
from dl_framework.loss_functions import init_feature_loss, loss_amp, loss_phase
10+
from dl_framework.loss_functions import (
11+
init_feature_loss,
12+
loss_amp,
13+
loss_phase,
14+
loss_msssim,
15+
loss_mse_msssim,
16+
loss_mse_msssim_phase,
17+
loss_mse_msssim_amp,
18+
loss_msssim_amp,
19+
)
1120
from dl_framework.callbacks import (
1221
AvgStatsCallback,
1322
BatchTransformXCallback,
@@ -188,6 +197,7 @@ def define_learner(
188197
opt_func=torch.optim.Adam,
189198
):
190199
cbfs.extend([
200+
# commented out because of normed and limited input values
191201
# partial(BatchTransformXCallback, norm),
192202
])
193203
if not test:
@@ -202,7 +212,7 @@ def define_learner(
202212
])
203213
if not test and not lr_find:
204214
cbfs.extend([
205-
partial(LoggerCallback, model_name=model_name),
215+
partial(LoggerCallback, model_name=model_name),
206216
data_aug,
207217
])
208218

@@ -216,8 +226,18 @@ def define_learner(
216226
loss_func = loss_amp
217227
elif loss_func == "loss_phase":
218228
loss_func = loss_phase
229+
elif loss_func == "msssim":
230+
loss_func = loss_msssim
231+
elif loss_func == "mse_msssim":
232+
loss_func = loss_mse_msssim
233+
elif loss_func == "mse_msssim_phase":
234+
loss_func = loss_mse_msssim_phase
235+
elif loss_func == "mse_msssim_amp":
236+
loss_func = loss_mse_msssim_amp
237+
elif loss_func == "msssim_amp":
238+
loss_func = loss_msssim_amp
219239
else:
220-
print("\n No matching loss function! Exiting. \n")
240+
print("\n No matching loss function or architecture! Exiting. \n")
221241
sys.exit(1)
222242

223243
# Combine model and data in learner

0 commit comments

Comments
 (0)
Please sign in to comment.