1818from model .DCCNN import DCCNN
1919from model .LPDNet import LPDNet
2020from model .HQSNet import HQSNet
21-
21+ from model .ISTANet_plus import ISTANetplus
22+ import numpy as np
2223
2324class Solver ():
2425 def __init__ (self , args ):
26+ torch .autograd .set_detect_anomaly (True )
2527 self .args = args
2628 ################ experiment settings ################
2729 self .model_name = self .args .model
@@ -43,22 +45,24 @@ def __init__(self, args):
4345 os .makedirs (self .saveDir )
4446
4547 self .task_name = self .model_name + '_acc_' + str (self .acc ) + '_bs_' + str (self .batch_size ) \
46- + '_lr_' + str (self .lr )
48+ + '_lr_' + str (self .lr ) + 'bf_5_nocat' #first_dc'#'_iter_10'#'_bf=1' #+ _nocat '_bf=1'#
4749 print ('task_name: ' , self .task_name )
4850 self .model_path = 'weight/' + self .task_name + '_best.pth' # model load path for test and visualization
4951
5052 ############################################ Specify network ############################################
5153 if self .model_name == 'dc-cnn' :
52- self .net = DCCNN ()
54+ self .net = DCCNN (n_iter = 8 )
55+ elif self .model_name == 'ista-net-plus' :
56+ self .net = ISTANetplus (n_iter = 8 )
5357 elif self .model_name == 'lpd-net' :
54- self .net = LPDNet ()
58+ self .net = LPDNet (n_iter = 8 )
5559 elif self .model_name == 'hqs-net' :
56- self .net = HQSNet (block_type = 'cnn' )
60+ self .net = HQSNet (block_type = 'cnn' , buffer_size = 5 , n_iter = 8 )
5761 elif self .model_name == 'hqs-net-unet' :
58- self .net = HQSNet (block_type = 'unet' )
62+ self .net = HQSNet (block_type = 'unet' , n_iter = 10 )
5963 else :
6064 assert "wrong model name !"
61- print ('Total # of model params: %.5fM' % (sum (p .numel () for p in self .net .parameters ()) / ( 1024.0 * 1024 ) ))
65+ print ('Total # of model params: %.5fM' % (sum (p .numel () for p in self .net .parameters ()) / 10. ** 6 ))
6266 self .net .cuda ()
6367
6468 def train (self ):
@@ -70,10 +74,10 @@ def train(self):
7074 ## 2. we train the hqs-net-unet model with ssim + l1 loss, the reason is that, we found when using ms-ssim loss,
7175 ## the gradient of ms-ssim may be nan. This bug exists in both pytorch and tensoflow implementation of ms-ssim loss.
7276 ## see https://github.com/tensorflow/tensorflow/issues/50400, https://github.com/VainF/pytorch-msssim/issues/12
73- if self .model_name == 'hqs-net-unet' :
74- self .criterion = CompoundLoss ('ssim' )
75- else :
76- self .criterion = CompoundLoss ('ms-ssim' )
77+ # if self.model_name == 'hqs-net-unet':
78+ # self.criterion = CompoundLoss('ssim')
79+ # else:
80+ self .criterion = CompoundLoss ('ms-ssim' )
7781
7882 ############################################ Specify optimizer ########################################
7983
@@ -84,10 +88,12 @@ def train(self):
8488 dataset_train = MyData (self .imageDir_train , self .acc , self .img_size , is_training = 'train' )
8589 dataset_val = MyData (self .imageDir_val , self .acc , self .img_size , is_training = 'val' )
8690
91+ num_workers = 4
92+ use_pin_memory = True
8793 loader_train = Data .DataLoader (dataset_train , batch_size = self .batch_size , shuffle = True , drop_last = True ,
88- num_workers = 4 , pin_memory = True )
94+ num_workers = num_workers , pin_memory = use_pin_memory )
8995 loader_val = Data .DataLoader (dataset_val , batch_size = self .batch_size , shuffle = False , drop_last = False ,
90- num_workers = 4 , pin_memory = True )
96+ num_workers = num_workers , pin_memory = use_pin_memory )
9197 self .slices_val = len (dataset_val )
9298 print ("slices of 2d train data: " , len (dataset_train ))
9399 print ("slices of 2d validation data: " , len (dataset_val ))
@@ -99,16 +105,22 @@ def train(self):
99105
100106 start_epoch = 0
101107 best_val_psnr = 0
108+ if 0 :
109+ best_name = self .task_name + '_best.pth'
110+ checkpoint = torch .load (join (self .saveDir , best_name ))
111+ self .net .load_state_dict (checkpoint ['net' ])
112+ start_epoch = checkpoint ['epoch' ]+ 1
113+ best_val_psnr = checkpoint ['val_psnr' ]
114+ print ('load pretrained model---, start epoch at, ' ,start_epoch , ', star_psnr_val is: ' ,best_val_psnr )
102115 for epoch in range (start_epoch , self .num_epoch ):
103116 ####################### 1. training #######################
104117
105118 loss_g = self ._train_cnn (loader_train )
106119 ####################### 2. validate #######################
120+ if epoch == start_epoch :
121+ base_psnr , base_ssim = self ._validate_base (loader_val )
107122 if epoch % self .val_on_epochs == 0 :
108- if epoch == 0 :
109- base_psnr , base_ssim = self ._validate_base (loader_val )
110123 val_psnr , val_ssim = self ._validate (loader_val )
111-
112124 ########################## 3. print and tensorboard ########################
113125 print ("Epoch {}/{}" .format (epoch + 1 , self .num_epoch ))
114126 print (" base PSNR:\t \t {:.6f}" .format (base_psnr ))
@@ -148,19 +160,23 @@ def test(self):
148160 self .net .cuda ()
149161 self .net .eval ()
150162
151- base_psnr = 0
152- test_psnr = 0
153- base_ssim = 0
154- test_ssim = 0
155- base_nrmse = 0
156- test_nrmse = 0
163+ base_psnr = []
164+ test_psnr = []
165+ base_ssim = []
166+ test_ssim = []
167+ base_nrmse = []
168+ test_nrmse = []
157169 with torch .no_grad ():
158170 time_0 = time .time ()
159171 for data_dict in tqdm (loader_val ):
160172 im_A , im_A_und , k_A_und , mask = data_dict ['im_A' ].float ().cuda (), data_dict ['im_A_und' ].float ().cuda (), \
161173 data_dict ['k_A_und' ].float ().cuda (), \
162174 data_dict ['mask_A' ].float ().cuda ()
163- T1 = self .net (im_A_und , k_A_und , mask )
175+
176+ if self .model_name == 'ista-net-plus' :
177+ T1 , loss_layers_sym = self .net (im_A_und , k_A_und , mask )
178+ else :
179+ T1 = self .net (im_A_und , k_A_und , mask )
164180 ############## convert model ouput to complex value in original range
165181
166182 T1 = output2complex (T1 )
@@ -170,42 +186,45 @@ def test(self):
170186 ########################### calulate metrics ###################################
171187 for T1_i , im_A_i , im_A_und_i in zip (T1 .cpu ().numpy (), im_A .cpu ().numpy (), im_A_und .cpu ().numpy ()):
172188 ## for skimage.metrics, input is (im_true,im_pred)
173- base_nrmse += cal_nrmse (im_A_i , im_A_und_i )
174- test_nrmse += cal_nrmse (im_A_i , T1_i )
175- base_ssim += cal_ssim (im_A_i , im_A_und_i )
176- test_ssim += cal_ssim (im_A_i , T1_i )
177- base_psnr += cal_psnr (im_A_i , im_A_und_i , data_range = im_A_i .max ())
178- test_psnr += cal_psnr (im_A_i , T1_i , data_range = im_A_i .max ())
189+ base_nrmse . append ( cal_nrmse (im_A_i , im_A_und_i ) )
190+ test_nrmse . append ( cal_nrmse (im_A_i , T1_i ) )
191+ base_ssim . append ( cal_ssim (im_A_i , im_A_und_i , data_range = im_A_i . max ()) )
192+ test_ssim . append ( cal_ssim (im_A_i , T1_i , data_range = im_A_i . max ()) )
193+ base_psnr . append ( cal_psnr (im_A_i , im_A_und_i , data_range = im_A_i .max () ))
194+ test_psnr . append ( cal_psnr (im_A_i , T1_i , data_range = im_A_i .max () ))
179195
180196 time_1 = time .time ()
181197 ## comment metric calculation code for more precise inference speed
182198 print ('inference speed: {:.5f} ms/slice' .format (1000 * (time_1 - time_0 ) / len_data ))
183- base_psnr /= len_data
184- test_psnr /= len_data
185- base_ssim /= len_data
186- test_ssim /= len_data
187- base_nrmse /= len_data
188- test_nrmse /= len_data
189-
190- print (" base PSNR:\t \t {:.6f}" .format (base_psnr ))
191- print (" test PSNR:\t \t {:.6f}" .format (test_psnr ))
192- print (" base SSIM:\t \t {:.6f}" .format (base_ssim ))
193- print (" test SSIM:\t \t {:.6f}" .format (test_ssim ))
194- print (" base NRMSE:\t \t {:.6f}" .format (base_nrmse ))
195- print (" test NRMSE:\t \t {:.6f}" .format (test_nrmse ))
199+
200+ print (" base PSNR:\t \t {:.6f}, std: {:.6f}" .format (np .mean (base_psnr ),np .std (base_psnr )))
201+ print (" test PSNR:\t \t {:.6f}, std: {:.6f}" .format (np .mean (test_psnr ),np .std (test_psnr )))
202+ print (" base SSIM:\t \t {:.6f}, std: {:.6f}" .format (np .mean (base_ssim ),np .std (base_ssim )))
203+ print (" test SSIM:\t \t {:.6f}, std: {:.6f}" .format (np .mean (test_ssim ),np .std (test_ssim )))
204+ print (" base NRMSE:\t \t {:.6f}, std: {:.6f}" .format (np .mean (base_nrmse ),np .std (base_nrmse )))
205+ print (" test NRMSE:\t \t {:.6f}, std: {:.6f}" .format (np .mean (test_nrmse ),np .std (test_nrmse )))
196206
197207 def _train_cnn (self , loader_train ):
198208 self .net .train ()
199209 for data_dict in tqdm (loader_train ):
200210 im_A , im_A_und , k_A_und , mask = data_dict ['im_A' ].float ().cuda (), data_dict [
201211 'im_A_und' ].float ().cuda (), data_dict ['k_A_und' ].float ().cuda (), data_dict ['mask_A' ].float ().cuda ()
202- T1 = self .net (im_A_und , k_A_und , mask )
212+ if self .model_name == 'ista-net-plus' :
213+ T1 ,loss_layers_sym = self .net (im_A_und , k_A_und , mask )
214+ else :
215+ T1 = self .net (im_A_und , k_A_und , mask )
203216
204217 T1 = output2complex (T1 )
205218 im_A = output2complex (im_A )
206219 ############################################# 1.2 update generator #############################################
207220
208221 loss_g = self .criterion (T1 , im_A , data_range = im_A .max ())
222+ if self .model_name == 'ista-net-plus' :
223+ loss_constraint = torch .mean (torch .pow (loss_layers_sym [0 ], 2 ))
224+ for k in range (len (loss_layers_sym )- 1 ):
225+ loss_constraint += torch .mean (torch .pow (loss_layers_sym [k + 1 ], 2 ))
226+ loss_g = loss_g + 0.01 * loss_constraint
227+
209228 self .optimizer_G .zero_grad ()
210229 loss_g .backward ()
211230 self .optimizer_G .step ()
@@ -244,7 +263,10 @@ def _validate(self, loader_val):
244263 im_A , im_A_und , k_A_und , mask = data_dict ['im_A' ].float ().cuda (), data_dict [
245264 'im_A_und' ].float ().cuda (), data_dict ['k_A_und' ].float ().cuda (), data_dict [
246265 'mask_A' ].float ().cuda ()
247- T1 = self .net (im_A_und , k_A_und , mask )
266+ if self .model_name == 'ista-net-plus' :
267+ T1 ,_ = self .net (im_A_und , k_A_und , mask )
268+ else :
269+ T1 = self .net (im_A_und , k_A_und , mask )
248270 ############## convert model ouput to complex value in original range
249271 T1 = output2complex (T1 )
250272 im_A = output2complex (im_A )
0 commit comments