Skip to content

Commit c29814f

Browse files
committed
updates after rebuttal: DC-CNN changed to soft format, add ISTA-Net+
1 parent e779ce7 commit c29814f

34 files changed

Lines changed: 406 additions & 65 deletions

Solver.py

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
from model.DCCNN import DCCNN
1919
from model.LPDNet import LPDNet
2020
from model.HQSNet import HQSNet
21-
21+
from model.ISTANet_plus import ISTANetplus
22+
import numpy as np
2223

2324
class Solver():
2425
def __init__(self, args):
26+
torch.autograd.set_detect_anomaly(True)
2527
self.args = args
2628
################ experiment settings ################
2729
self.model_name = self.args.model
@@ -43,22 +45,24 @@ def __init__(self, args):
4345
os.makedirs(self.saveDir)
4446

4547
self.task_name = self.model_name + '_acc_' + str(self.acc) + '_bs_' + str(self.batch_size) \
46-
+ '_lr_' + str(self.lr)
48+
+ '_lr_' + str(self.lr) + 'bf_5_nocat' #first_dc'#'_iter_10'#'_bf=1' #+ _nocat '_bf=1'#
4749
print('task_name: ', self.task_name)
4850
self.model_path = 'weight/' + self.task_name + '_best.pth' # model load path for test and visualization
4951

5052
############################################ Specify network ############################################
5153
if self.model_name == 'dc-cnn':
52-
self.net = DCCNN()
54+
self.net = DCCNN(n_iter=8)
55+
elif self.model_name == 'ista-net-plus':
56+
self.net = ISTANetplus(n_iter=8)
5357
elif self.model_name == 'lpd-net':
54-
self.net = LPDNet()
58+
self.net = LPDNet(n_iter=8)
5559
elif self.model_name == 'hqs-net':
56-
self.net = HQSNet(block_type='cnn')
60+
self.net = HQSNet(block_type='cnn',buffer_size=5, n_iter=8)
5761
elif self.model_name == 'hqs-net-unet':
58-
self.net = HQSNet(block_type='unet')
62+
self.net = HQSNet(block_type='unet', n_iter=10)
5963
else:
6064
assert "wrong model name !"
61-
print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / (1024.0 * 1024)))
65+
print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / 10.**6))
6266
self.net.cuda()
6367

6468
def train(self):
@@ -70,10 +74,10 @@ def train(self):
7074
## 2. we train the hqs-net-unet model with ssim + l1 loss, the reason is that, we found when using ms-ssim loss,
7175
## the gradient of ms-ssim may be nan. This bug exists in both pytorch and tensoflow implementation of ms-ssim loss.
7276
## see https://github.com/tensorflow/tensorflow/issues/50400, https://github.com/VainF/pytorch-msssim/issues/12
73-
if self.model_name == 'hqs-net-unet':
74-
self.criterion = CompoundLoss('ssim')
75-
else:
76-
self.criterion = CompoundLoss('ms-ssim')
77+
# if self.model_name == 'hqs-net-unet':
78+
# self.criterion = CompoundLoss('ssim')
79+
# else:
80+
self.criterion = CompoundLoss('ms-ssim')
7781

7882
############################################ Specify optimizer ########################################
7983

@@ -84,10 +88,12 @@ def train(self):
8488
dataset_train = MyData(self.imageDir_train, self.acc, self.img_size, is_training='train')
8589
dataset_val = MyData(self.imageDir_val, self.acc, self.img_size, is_training='val')
8690

91+
num_workers = 4
92+
use_pin_memory = True
8793
loader_train = Data.DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=True,
88-
num_workers=4, pin_memory=True)
94+
num_workers=num_workers, pin_memory=use_pin_memory)
8995
loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False,
90-
num_workers=4, pin_memory=True)
96+
num_workers=num_workers, pin_memory=use_pin_memory)
9197
self.slices_val = len(dataset_val)
9298
print("slices of 2d train data: ", len(dataset_train))
9399
print("slices of 2d validation data: ", len(dataset_val))
@@ -99,16 +105,22 @@ def train(self):
99105

100106
start_epoch = 0
101107
best_val_psnr = 0
108+
if 0:
109+
best_name = self.task_name + '_best.pth'
110+
checkpoint = torch.load(join(self.saveDir, best_name))
111+
self.net.load_state_dict(checkpoint['net'])
112+
start_epoch = checkpoint['epoch']+1
113+
best_val_psnr = checkpoint['val_psnr']
114+
print('load pretrained model---, start epoch at, ',start_epoch, ', star_psnr_val is: ',best_val_psnr)
102115
for epoch in range(start_epoch, self.num_epoch):
103116
####################### 1. training #######################
104117

105118
loss_g = self._train_cnn(loader_train)
106119
####################### 2. validate #######################
120+
if epoch == start_epoch:
121+
base_psnr, base_ssim = self._validate_base(loader_val)
107122
if epoch % self.val_on_epochs == 0:
108-
if epoch == 0:
109-
base_psnr, base_ssim = self._validate_base(loader_val)
110123
val_psnr, val_ssim = self._validate(loader_val)
111-
112124
########################## 3. print and tensorboard ########################
113125
print("Epoch {}/{}".format(epoch + 1, self.num_epoch))
114126
print(" base PSNR:\t\t{:.6f}".format(base_psnr))
@@ -148,19 +160,23 @@ def test(self):
148160
self.net.cuda()
149161
self.net.eval()
150162

151-
base_psnr = 0
152-
test_psnr = 0
153-
base_ssim = 0
154-
test_ssim = 0
155-
base_nrmse = 0
156-
test_nrmse = 0
163+
base_psnr = []
164+
test_psnr = []
165+
base_ssim = []
166+
test_ssim = []
167+
base_nrmse = []
168+
test_nrmse = []
157169
with torch.no_grad():
158170
time_0 = time.time()
159171
for data_dict in tqdm(loader_val):
160172
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda(), \
161173
data_dict['k_A_und'].float().cuda(), \
162174
data_dict['mask_A'].float().cuda()
163-
T1 = self.net(im_A_und, k_A_und, mask)
175+
176+
if self.model_name == 'ista-net-plus':
177+
T1, loss_layers_sym = self.net(im_A_und, k_A_und, mask)
178+
else:
179+
T1 = self.net(im_A_und, k_A_und, mask)
164180
############## convert model ouput to complex value in original range
165181

166182
T1 = output2complex(T1)
@@ -170,42 +186,45 @@ def test(self):
170186
########################### calulate metrics ###################################
171187
for T1_i, im_A_i, im_A_und_i in zip(T1.cpu().numpy(), im_A.cpu().numpy(), im_A_und.cpu().numpy()):
172188
## for skimage.metrics, input is (im_true,im_pred)
173-
base_nrmse += cal_nrmse(im_A_i, im_A_und_i)
174-
test_nrmse += cal_nrmse(im_A_i, T1_i)
175-
base_ssim += cal_ssim(im_A_i, im_A_und_i)
176-
test_ssim += cal_ssim(im_A_i, T1_i)
177-
base_psnr += cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max())
178-
test_psnr += cal_psnr(im_A_i, T1_i, data_range=im_A_i.max())
189+
base_nrmse.append(cal_nrmse(im_A_i, im_A_und_i))
190+
test_nrmse.append(cal_nrmse(im_A_i, T1_i))
191+
base_ssim.append(cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max()))
192+
test_ssim.append(cal_ssim(im_A_i, T1_i, data_range=im_A_i.max()))
193+
base_psnr.append(cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max()))
194+
test_psnr.append(cal_psnr(im_A_i, T1_i, data_range=im_A_i.max()))
179195

180196
time_1 = time.time()
181197
## comment metric calculation code for more precise inference speed
182198
print('inference speed: {:.5f} ms/slice'.format(1000 * (time_1 - time_0) / len_data))
183-
base_psnr /= len_data
184-
test_psnr /= len_data
185-
base_ssim /= len_data
186-
test_ssim /= len_data
187-
base_nrmse /= len_data
188-
test_nrmse /= len_data
189-
190-
print(" base PSNR:\t\t{:.6f}".format(base_psnr))
191-
print(" test PSNR:\t\t{:.6f}".format(test_psnr))
192-
print(" base SSIM:\t\t{:.6f}".format(base_ssim))
193-
print(" test SSIM:\t\t{:.6f}".format(test_ssim))
194-
print(" base NRMSE:\t\t{:.6f}".format(base_nrmse))
195-
print(" test NRMSE:\t\t{:.6f}".format(test_nrmse))
199+
200+
print(" base PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_psnr),np.std(base_psnr)))
201+
print(" test PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_psnr),np.std(test_psnr)))
202+
print(" base SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_ssim),np.std(base_ssim)))
203+
print(" test SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_ssim),np.std(test_ssim)))
204+
print(" base NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_nrmse),np.std(base_nrmse)))
205+
print(" test NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_nrmse),np.std(test_nrmse)))
196206

197207
def _train_cnn(self, loader_train):
198208
self.net.train()
199209
for data_dict in tqdm(loader_train):
200210
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
201211
'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict['mask_A'].float().cuda()
202-
T1 = self.net(im_A_und, k_A_und, mask)
212+
if self.model_name == 'ista-net-plus':
213+
T1,loss_layers_sym = self.net(im_A_und, k_A_und, mask)
214+
else:
215+
T1 = self.net(im_A_und, k_A_und, mask)
203216

204217
T1 = output2complex(T1)
205218
im_A = output2complex(im_A)
206219
############################################# 1.2 update generator #############################################
207220

208221
loss_g = self.criterion(T1, im_A, data_range=im_A.max())
222+
if self.model_name == 'ista-net-plus':
223+
loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
224+
for k in range(len(loss_layers_sym)-1):
225+
loss_constraint += torch.mean(torch.pow(loss_layers_sym[k + 1], 2))
226+
loss_g = loss_g + 0.01 * loss_constraint
227+
209228
self.optimizer_G.zero_grad()
210229
loss_g.backward()
211230
self.optimizer_G.step()
@@ -244,7 +263,10 @@ def _validate(self, loader_val):
244263
im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[
245264
'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict[
246265
'mask_A'].float().cuda()
247-
T1 = self.net(im_A_und, k_A_und, mask)
266+
if self.model_name == 'ista-net-plus':
267+
T1,_ = self.net(im_A_und, k_A_und, mask)
268+
else:
269+
T1 = self.net(im_A_und, k_A_und, mask)
248270
############## convert model ouput to complex value in original range
249271
T1 = output2complex(T1)
250272
im_A = output2complex(im_A)

dd.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# @author : Bingyu Xin
2+
# @Institute : CS@Rutgers

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def main(args):
1919
############################### experiment settings ##########################
2020
parser.add_argument('--mode', default='train', choices=['train', 'test'],
2121
help='mode for the program')
22-
parser.add_argument('--model', default='hqs-net', choices=['dc-cnn', 'lpd-net', 'hqs-net', 'hqs-net-unet'],
22+
parser.add_argument('--model', default='hqs-net', choices=['dc-cnn', 'lpd-net', 'hqs-net', 'hqs-net-unet','ista-net-plus'],
2323
help='models to reconstruct')
2424
parser.add_argument('--acc', type=int, default=5,
2525
help='Acceleration factor for k-space sampling')

model/BasicModule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ def conv_block(model_name='hqs-net', channel_in=22, n_convs=3, n_filters=32):
2020
layers = []
2121
if model_name == 'dc-cnn':
2222
channel_out = channel_in
23+
if model_name == 'ista-net':
24+
channel_out = n_filters
2325
elif model_name == 'prim-net' or model_name == 'hqs-net':
2426
channel_out = channel_in - 2
2527
elif model_name == 'dual-net':

model/DCCNN.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class DCCNN(nn.Module):
9-
def __init__(self, n_iter=5, n_convs=5, n_filters=64, norm='ortho'):
9+
def __init__(self, n_iter=8, n_convs=6, n_filters=64, norm='ortho'):
1010
'''
1111
DC-CNN modified from paper " A Deep Cascade of Convolutional Neural Networks for Dynamic MR Image Reconstruction "
1212
( https://arxiv.org/pdf/1704.02422.pdf ) ( https://github.com/js3611/Deep-MRI-Reconstruction )
@@ -19,6 +19,7 @@ def __init__(self, n_iter=5, n_convs=5, n_filters=64, norm='ortho'):
1919
channel_in = 2
2020
rec_blocks = []
2121
self.norm = norm
22+
self.mu = nn.Parameter(torch.Tensor([0.5]))
2223
self.n_iter = n_iter
2324
for i in range(n_iter):
2425
rec_blocks.append(conv_block('dc-cnn', channel_in, n_filters=n_filters, n_convs=n_convs))
@@ -32,18 +33,39 @@ def dc_operation(self, x_rec, k_un, mask):
3233
k_rec = torch.fft.fft2(torch.view_as_complex(x_rec.contiguous()), norm=self.norm)
3334

3435
k_rec = torch.view_as_real(k_rec)
36+
# noiseless
3537
k_out = k_rec + (k_un - k_rec) * mask
3638

3739
k_out = torch.view_as_complex(k_out)
3840
x_out = torch.view_as_real(torch.fft.ifft2(k_out, norm=self.norm))
3941
x_out = x_out.permute(0, 3, 1, 2)
4042
return x_out
43+
def _forward_operation(self, img, mask):
44+
45+
k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()),
46+
norm=self.norm)
47+
k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous()
48+
k = mask * k
49+
return k
50+
51+
def _backward_operation(self, k, mask):
52+
53+
k = mask * k
54+
img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm)
55+
img = torch.view_as_real(img).permute(0, 3, 1, 2).contiguous()
56+
return img
57+
58+
def update_opration(self, f_1, k, mask):
59+
h_1 = k - self._forward_operation(f_1, mask)
60+
update = f_1 + self.mu * self._backward_operation(h_1, mask)
61+
return update
4162

4263
def forward(self, x, k, m):
4364
for i in range(self.n_iter):
65+
# x = self.update_opration(x, k, m)
4466
x_cnn = self.rec_blocks[i](x)
4567
x = x + x_cnn
46-
x = self.dc_operation(x, k, m)
68+
x = self.update_opration(x, k, m)
4769
return x
4870

4971

model/HQSNet.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class HQSNet(nn.Module):
10-
def __init__(self, buffer_size=5, n_iter=10, n_convs=3, n_filters=32, block_type='cnn', norm='ortho'):
10+
def __init__(self, buffer_size=5, n_iter=8, n_convs=6, n_filters=64, block_type='cnn', norm='ortho'):
1111
'''
1212
HQS-Net
1313
:param buffer_size: buffer_size m
@@ -20,13 +20,13 @@ def __init__(self, buffer_size=5, n_iter=10, n_convs=3, n_filters=32, block_type
2020
self.m = buffer_size
2121
self.n_iter = n_iter
2222
## the initialization of mu may influence the final accuracy
23-
self.mu = nn.Parameter(2. * torch.ones((1, 1)))
23+
self.mu = nn.Parameter(0.5 * torch.ones((1, 1))) #2
2424
self.block_type = block_type
2525
if self.block_type == 'cnn':
2626
rec_blocks = []
2727
for i in range(self.n_iter):
2828
rec_blocks.append(
29-
conv_block('hqs-net', channel_in=2 * (self.m + 1), n_convs=n_convs, n_filters=n_filters))
29+
conv_block('hqs-net', channel_in=2 * (self.m+1 ), n_convs=n_convs, n_filters=n_filters)) #self.m +
3030
self.rec_blocks = nn.ModuleList(rec_blocks)
3131
elif self.block_type == 'unet':
3232
self.rec_blocks = UNetRes(in_nc=2 * (self.m + 1), out_nc=2 * self.m, nc=[64, 128, 256, 512], nb=4,
@@ -56,14 +56,27 @@ def update_opration(self, f_1, k, mask):
5656
def forward(self, img, k, mask):
5757

5858
## initialize buffer f : the concatenation of m copies of the complex-valued zero-filled images
59+
5960
f = torch.cat([img] * self.m, 1).to(img.device)
6061

62+
## n reconstruction blocks buff=5_nocat
63+
# for i in range(self.n_iter):
64+
# for j in range(self.m):
65+
# f_1 = f[:, j*2:j*2+2].clone()
66+
# f[:, j*2:j*2+2] = self.update_opration(f_1, k, mask)
67+
# if self.block_type == 'cnn':
68+
# # f = f + self.rec_blocks[i](torch.cat([f, updated_f_1], 1))
69+
# f = f + self.rec_blocks[i](f)
70+
# elif self.block_type == 'unet':
71+
# f = f + self.rec_blocks(torch.cat([f, updated_f_1], 1))
72+
6173
## n reconstruction blocks
6274
for i in range(self.n_iter):
6375
f_1 = f[:, 0:2].clone()
6476
updated_f_1 = self.update_opration(f_1, k, mask)
6577
if self.block_type == 'cnn':
6678
f = f + self.rec_blocks[i](torch.cat([f, updated_f_1], 1))
79+
# f = updated_f_1 + self.rec_blocks[i](updated_f_1)
6780
elif self.block_type == 'unet':
6881
f = f + self.rec_blocks(torch.cat([f, updated_f_1], 1))
6982
return f[:, 0:2]

0 commit comments

Comments
 (0)