|
| 1 | +############################################################## |
| 2 | +# @author: pkc |
| 3 | +# |
| 4 | +# banded_plots.py |
| 5 | +# ............ |
| 6 | +# includes codes to output different banded plots shown in the |
| 7 | +# sfrc paper |
| 8 | +# comment the line mpi.use('Agg') in the file src/plot_func.py |
| 9 | +# to view plot as display |
| 10 | + |
| 11 | +import sys |
| 12 | +sys.path.append('../') |
| 13 | +import numpy as np |
| 14 | +import numpy.fft as fft |
| 15 | +from src import frc_utils |
| 16 | +from src import io_func |
| 17 | +from src import frc_utils |
| 18 | +from src import plot_func as pf |
| 19 | +import os |
| 20 | +from src import utils |
| 21 | + |
| 22 | +def get3_comp_of_img(img, band1_ep, band2_ep, band3_ep): |
| 23 | + h, w = img.shape |
| 24 | + img_p1 = np.zeros([h,w]).astype('complex') |
| 25 | + img_p2 = np.zeros([h,w]).astype('complex') |
| 26 | + img_p3 = np.zeros([h,w]).astype('complex') |
| 27 | + img_fft = fft.fftshift(fft.fft2(img)) |
| 28 | + |
| 29 | + indices=frc_utils.ring_indices(img_fft) |
| 30 | + for i in range(0, band1_ep): |
| 31 | + img_p1[indices[i]]=img_fft[indices[i]] |
| 32 | + for i in range(band1_ep, band2_ep): |
| 33 | + img_p2[indices[i]]=img_fft[indices[i]] |
| 34 | + for i in range(band2_ep, band3_ep): |
| 35 | + img_p3[indices[i]]=img_fft[indices[i]] |
| 36 | + |
| 37 | + img_p1 = np.real(fft.ifft2(fft.ifftshift(img_p1))) |
| 38 | + img_p2 = np.real(fft.ifft2(fft.ifftshift(img_p2))) |
| 39 | + img_p3 = np.real(fft.ifft2(fft.ifftshift(img_p3))) |
| 40 | + img_stack = np.stack((img.reshape(h,w), img_p1.reshape(h, w), img_p2.reshape(h, w), img_p3.reshape(h, w)), axis=0) |
| 41 | + return(img_p1, img_p2, img_p3, img_stack) |
| 42 | + |
| 43 | +def get3_fft_bands_of_img(img, band1_ep, band2_ep, band3_ep): |
| 44 | + h, w = img.shape |
| 45 | + img_p1 = np.zeros([h,w]) |
| 46 | + img_p2 = np.zeros([h,w]) |
| 47 | + img_p3 = np.zeros([h,w]) |
| 48 | + img_fft = np.log(np.abs(fft.fftshift(fft.fft2(img)))) |
| 49 | + |
| 50 | + indices=frc_utils.ring_indices(img_fft) |
| 51 | + for i in range(0, band1_ep): |
| 52 | + img_p1[indices[i]]=img_fft[indices[i]] |
| 53 | + for i in range(band1_ep, band2_ep): |
| 54 | + img_p2[indices[i]]=img_fft[indices[i]] |
| 55 | + for i in range(band2_ep, band3_ep): |
| 56 | + img_p3[indices[i]]=img_fft[indices[i]] |
| 57 | + img_stack = np.stack((img_fft.reshape(h,w), img_p1.reshape(h, w), img_p2.reshape(h, w), img_p3.reshape(h, w)), axis=0) |
| 58 | + return(img_p1, img_p2, img_p3, img_stack) |
| 59 | + |
| 60 | +def get5_comp_of_img(img, band1_ep, band2_ep, band3_ep, band4_ep, band5_ep): |
| 61 | + h, w = img.shape |
| 62 | + img_p1 = np.zeros([h,w]).astype('complex') |
| 63 | + img_p2 = np.zeros([h,w]).astype('complex') |
| 64 | + img_p3 = np.zeros([h,w]).astype('complex') |
| 65 | + img_p4 = np.zeros([h,w]).astype('complex') |
| 66 | + img_p5 = np.zeros([h,w]).astype('complex') |
| 67 | + img_fft = fft.fftshift(fft.fft2(img)) |
| 68 | + |
| 69 | + indices=frc_utils.ring_indices(img_fft) |
| 70 | + for i in range(0, band1_ep): |
| 71 | + img_p1[indices[i]]=img_fft[indices[i]] |
| 72 | + for i in range(band1_ep, band2_ep): |
| 73 | + img_p2[indices[i]]=img_fft[indices[i]] |
| 74 | + for i in range(band2_ep, band3_ep): |
| 75 | + img_p3[indices[i]]=img_fft[indices[i]] |
| 76 | + for i in range(band3_ep, band4_ep): |
| 77 | + img_p4[indices[i]]=img_fft[indices[i]] |
| 78 | + for i in range(band4_ep, band5_ep): |
| 79 | + img_p5[indices[i]]=img_fft[indices[i]] |
| 80 | + |
| 81 | + img_p1 = np.real(fft.ifft2(fft.ifftshift(img_p1))) |
| 82 | + img_p2 = np.real(fft.ifft2(fft.ifftshift(img_p2))) |
| 83 | + img_p3 = np.real(fft.ifft2(fft.ifftshift(img_p3))) |
| 84 | + img_p4 = np.real(fft.ifft2(fft.ifftshift(img_p4))) |
| 85 | + img_p5 = np.real(fft.ifft2(fft.ifftshift(img_p5))) |
| 86 | + img_stack = np.stack((img.reshape(h,w), img_p1.reshape(h, w), img_p2.reshape(h, w), img_p3.reshape(h, w), img_p4.reshape(h, w), img_p5.reshape(h, w)), axis=0) |
| 87 | + return(img_p1, img_p2, img_p3, img_p4, img_p5, img_stack) |
| 88 | + |
| 89 | +def get5_fft_bands_of_img(img, band1_ep, band2_ep, band3_ep, band4_ep, band5_ep): |
| 90 | + h, w = img.shape |
| 91 | + img_p1 = np.zeros([h,w]) |
| 92 | + img_p2 = np.zeros([h,w]) |
| 93 | + img_p3 = np.zeros([h,w]) |
| 94 | + img_p4 = np.zeros([h,w]) |
| 95 | + img_p5 = np.zeros([h,w]) |
| 96 | + img_fft = np.log(np.abs(fft.fftshift(fft.fft2(img)))) |
| 97 | + |
| 98 | + indices=frc_utils.ring_indices(img_fft) |
| 99 | + for i in range(0, band1_ep): |
| 100 | + img_p1[indices[i]]=img_fft[indices[i]] |
| 101 | + for i in range(band1_ep, band2_ep): |
| 102 | + img_p2[indices[i]]=img_fft[indices[i]] |
| 103 | + for i in range(band2_ep, band3_ep): |
| 104 | + img_p3[indices[i]]=img_fft[indices[i]] |
| 105 | + for i in range(band3_ep, band4_ep): |
| 106 | + img_p4[indices[i]]=img_fft[indices[i]] |
| 107 | + for i in range(band4_ep, band5_ep): |
| 108 | + img_p5[indices[i]]=img_fft[indices[i]] |
| 109 | + img_stack = np.stack((img_fft.reshape(h,w), img_p1.reshape(h, w), img_p2.reshape(h, w), img_p3.reshape(h, w), img_p4.reshape(h, w), img_p5.reshape(h, w) ), axis=0) |
| 110 | + return(img_p1, img_p2, img_p3, img_p4, img_p5, img_stack) |
| 111 | +# -------------------------------------- |
| 112 | +# Display and saving figure option |
| 113 | +# ----------------------------------------- |
| 114 | +plot_fig = True # Display plots |
| 115 | +save_fig = False # save plots |
| 116 | +crop_fig = True # banded plots related cropped image in the main paper vs full image in the supp paper |
| 117 | +Ncomp = 5 # no. of binned fft components. Other option is 3 |
| 118 | + |
| 119 | +if crop_fig: |
| 120 | + # ----------------------------------------------------------------- |
| 121 | + # this option corresponds to banded plots of the cropped ROIs |
| 122 | + # between FBP and SRGAN shown in the main paper. |
| 123 | + # cnn is pointed to SRGAN-based outputs |
| 124 | + # ----------------------------------------------------------------- |
| 125 | + gt_path = './plot2/crop_img_uint8_L_50_W_400/gt_000069.png' |
| 126 | + cnn_path = './plot2/crop_img_uint8_L_50_W_400/srgan_000069.png' #SRGAN outputs |
| 127 | + out_path = './plot2/crop_fig/Ncomp_'+ str(Ncomp)+ '/' |
| 128 | +else: |
| 129 | + # ----------------------------------------------------------------- |
| 130 | + # this option corresponds to the banded plots of full images |
| 131 | + # related to conventional artifacts showns in supplemental document. |
| 132 | + # cnn is pointed to missing-wedge-based outputs. |
| 133 | + # ----------------------------------------------------------------- |
| 134 | + #gt_path = '../nfk_artifacts/data/missing-wedge/gt_irt/mk_L291_tv_000163_uint16.tif' |
| 135 | + #cnn_path = '../nfk_artifacts/data/missing-wedge/input_irt/recon_theta_pm60_spac_2_uint16.tif' |
| 136 | + gt_path = './plot2/full_img_uint8_L_1256_W_780/mk_L291_tv_000163_L_1286_W_780_2_uint8.png' |
| 137 | + cnn_path = './plot2/full_img_uint8_L_1256_W_780/recon_theta_pm60_spac_2_L_1286_W_780_2_uint8.png' # missing_wedge |
| 138 | + dist_path = './plot2/full_img_uint8_L_1256_W_780/recon_theta_deg10_spac_0.5_L_1286_W_780_2_uint8.png' # distortion |
| 139 | + out_path = './plot2/full_fig/Ncomp_'+ str(Ncomp)+ '_uint8/' |
| 140 | + |
| 141 | +if not os.path.isdir(out_path): os.makedirs(out_path, ) |
| 142 | +if not crop_fig: dist_img = io_func.imageio_imread(dist_path) |
| 143 | + |
| 144 | +gt_img = io_func.imageio_imread(gt_path) |
| 145 | +cnn_img = io_func.imageio_imread(cnn_path) |
| 146 | +h, w = gt_img.shape |
| 147 | +r = int(h/2) |
| 148 | + |
| 149 | +if Ncomp == 3: |
| 150 | + fs_c1 = int(0.1*r) |
| 151 | + fs_c2 = int(0.25*r) |
| 152 | + |
| 153 | + cnn_c1, cnn_c2, cnn_c3, cnn_c_stack = get3_comp_of_img(cnn_img, fs_c1, fs_c2, r) |
| 154 | + gt_c1, gt_c2, gt_c3, gt_c_stack = get3_comp_of_img(gt_img, fs_c1, fs_c2, r) |
| 155 | + |
| 156 | + cnn_fft_c1, cnn_fft_c2, cnn_fft_c3, cnn_fft_c_stack = get3_fft_bands_of_img(cnn_img, fs_c1, fs_c2, r) |
| 157 | + |
| 158 | + if plot_fig: |
| 159 | + pf.multi2dplots(1, 4, cnn_c_stack, axis=0, passed_fig_att={'colorbar': False}) |
| 160 | + pf.multi2dplots(1, 4, cnn_fft_c_stack, axis=0, passed_fig_att={'colorbar': False}) |
| 161 | + pf.plot2dlayers(gt_img) |
| 162 | + |
| 163 | + if save_fig: |
| 164 | + if not os.path.isdir(out_path): os.makedirs(out_path, exist_ok=True) |
| 165 | + # cnn part |
| 166 | + cnn_c1 = utils.normalize_data_ab(0, 255, cnn_c1).astype('uint8') |
| 167 | + cnn_c2 = utils.normalize_data_ab(0, 255, cnn_c2).astype('uint8') |
| 168 | + cnn_c3 = utils.normalize_data_ab(0, 255, cnn_c3).astype('uint8') |
| 169 | + |
| 170 | + #gt parts |
| 171 | + gt_c1 = utils.normalize_data_ab(0, 255, gt_c1).astype('uint8') |
| 172 | + gt_c2 = utils.normalize_data_ab(0, 255, gt_c2).astype('uint8') |
| 173 | + gt_c3 = utils.normalize_data_ab(0, 255, gt_c3).astype('uint8') |
| 174 | + #cnn fft parts |
| 175 | + cnn_fft_c1 = utils.normalize_data_ab(0, 255, cnn_fft_c1).astype('uint8') |
| 176 | + cnn_fft_c2 = utils.normalize_data_ab(0, 255, cnn_fft_c2).astype('uint8') |
| 177 | + cnn_fft_c3 = utils.normalize_data_ab(0, 255, cnn_fft_c3).astype('uint8') |
| 178 | + cnn_fft = utils.normalize_data_ab(0, 255, cnn_fft_c_stack[0]).astype('uint8') |
| 179 | + io_func.imsave(cnn_c1, path=out_path + 'cnn_c1.png', svtype='original') |
| 180 | + io_func.imsave(cnn_c2, path=out_path + 'cnn_c2.png', svtype='original') |
| 181 | + io_func.imsave(cnn_c3, path=out_path + 'cnn_c3.png', svtype='original') |
| 182 | + |
| 183 | + io_func.imsave(gt_c1, path=out_path + 'gt_c1.png', svtype='original') |
| 184 | + io_func.imsave(gt_c2, path=out_path + 'gt_c2.png', svtype='original') |
| 185 | + io_func.imsave(gt_c3, path=out_path + 'gt_c3.png', svtype='original') |
| 186 | + |
| 187 | + io_func.imsave(cnn_fft_c1, path=out_path + 'cnn_fft_c1.png', svtype='original') |
| 188 | + io_func.imsave(cnn_fft_c2, path=out_path + 'cnn_fft_c2.png', svtype='original') |
| 189 | + io_func.imsave(cnn_fft_c3, path=out_path + 'cnn_fft_c3.png', svtype='original') |
| 190 | + io_func.imsave(cnn_fft, path=out_path + 'cnn_fft.png', svtype='original') |
| 191 | + |
| 192 | +elif Ncomp==5: |
| 193 | + fs_c1 = int(0.1*r) |
| 194 | + fs_c2 = int(0.25*r) |
| 195 | + fs_c3 = int(0.5*r) |
| 196 | + fs_c4 = int(0.75*r) |
| 197 | + |
| 198 | + cnn_c1, cnn_c2, cnn_c3, cnn_c4, cnn_c5, cnn_c_stack = get5_comp_of_img(cnn_img, fs_c1, fs_c2, fs_c3, fs_c4, r) # missing-wedge-img/srgan-patch |
| 199 | + gt_c1, gt_c2, gt_c3, gt_c4, gt_c5, gt_c_stack = get5_comp_of_img(gt_img, fs_c1, fs_c2, fs_c3, fs_c4, r) |
| 200 | + cnn_fft_c1, cnn_fft_c2, cnn_fft_c3, cnn_fft_c4, cnn_fft_c5, cnn_fft_c_stack = get5_fft_bands_of_img(cnn_img, fs_c1, fs_c2, fs_c3, fs_c4, r) |
| 201 | + |
| 202 | + if not crop_fig: |
| 203 | + dist_c1, dist_c2, dist_c3, dist_c4, dist_c5, dist_c_stack = get5_comp_of_img(dist_img, fs_c1, fs_c2, fs_c3, fs_c4, r) |
| 204 | + |
| 205 | + if plot_fig: |
| 206 | + if crop_fig: |
| 207 | + # the three rows correspond to FFT of SRGAN, SRGAN and FBP |
| 208 | + band_stacks = np.stack((cnn_fft_c_stack.reshape(6, h, w), cnn_c_stack.reshape(6, h, w), gt_c_stack.reshape(6, h, w)), axis=0) |
| 209 | + print('shape of the subplots:', (band_stacks.reshape(18, h, w)).shape) |
| 210 | + pf.multi2dplots(3, 6, band_stacks.reshape(18, h, w), axis=0, passed_fig_att={'colorbar': False, 'figsize': [8, 6]})#, 'out_path': out_path + 'all_bands.png'}) |
| 211 | + else: |
| 212 | + band_stacks = np.stack((cnn_fft_c_stack.reshape(6, h, w), cnn_c_stack.reshape(6, h, w), dist_c_stack.reshape(6, h, w), gt_c_stack.reshape(6, h, w)), axis=0) |
| 213 | + print('shape of the subplots:', (band_stacks.reshape(24, h, w)).shape) |
| 214 | + pf.multi2dplots(4, 6, band_stacks.reshape(24, h, w), axis=0, passed_fig_att={'colorbar': False, 'figsize': [8, 6]})#, 'out_path': out_path + 'all_bands.png'}) |
| 215 | + |
| 216 | + if save_fig: |
| 217 | + if not os.path.isdir(out_path): os.makedirs(out_path, exist_ok=True) |
| 218 | + # cnn part |
| 219 | + cnn_c1 = utils.normalize_data_ab(0, 255, cnn_c1).astype('uint8') |
| 220 | + cnn_c2 = utils.normalize_data_ab(0, 255, cnn_c2).astype('uint8') |
| 221 | + cnn_c3 = utils.normalize_data_ab(0, 255, cnn_c3).astype('uint8') |
| 222 | + cnn_c4 = utils.normalize_data_ab(0, 255, cnn_c4).astype('uint8') |
| 223 | + cnn_c5 = utils.normalize_data_ab(0, 255, cnn_c5).astype('uint8') |
| 224 | + |
| 225 | + #gt parts |
| 226 | + gt_c1 = utils.normalize_data_ab(0, 255, gt_c1).astype('uint8') |
| 227 | + gt_c2 = utils.normalize_data_ab(0, 255, gt_c2).astype('uint8') |
| 228 | + gt_c3 = utils.normalize_data_ab(0, 255, gt_c3).astype('uint8') |
| 229 | + gt_c4 = utils.normalize_data_ab(0, 255, gt_c4).astype('uint8') |
| 230 | + gt_c5 = utils.normalize_data_ab(0, 255, gt_c5).astype('uint8') |
| 231 | + |
| 232 | + #cnn fft parts |
| 233 | + cnn_fft_c1 = utils.normalize_data_ab(0, 255, cnn_fft_c1).astype('uint8') |
| 234 | + cnn_fft_c2 = utils.normalize_data_ab(0, 255, cnn_fft_c2).astype('uint8') |
| 235 | + cnn_fft_c3 = utils.normalize_data_ab(0, 255, cnn_fft_c3).astype('uint8') |
| 236 | + cnn_fft_c4 = utils.normalize_data_ab(0, 255, cnn_fft_c4).astype('uint8') |
| 237 | + cnn_fft_c5 = utils.normalize_data_ab(0, 255, cnn_fft_c5).astype('uint8') |
| 238 | + cnn_fft = utils.normalize_data_ab(0, 255, cnn_fft_c_stack[0]).astype('uint8') |
| 239 | + |
| 240 | + io_func.imsave(cnn_c1, path=out_path + 'cnn_c1.png', svtype='original') |
| 241 | + io_func.imsave(cnn_c2, path=out_path + 'cnn_c2.png', svtype='original') |
| 242 | + io_func.imsave(cnn_c3, path=out_path + 'cnn_c3.png', svtype='original') |
| 243 | + io_func.imsave(cnn_c4, path=out_path + 'cnn_c4.png', svtype='original') |
| 244 | + io_func.imsave(cnn_c5, path=out_path + 'cnn_c5.png', svtype='original') |
| 245 | + |
| 246 | + io_func.imsave(gt_c1, path=out_path + 'gt_c1.png', svtype='original') |
| 247 | + io_func.imsave(gt_c2, path=out_path + 'gt_c2.png', svtype='original') |
| 248 | + io_func.imsave(gt_c3, path=out_path + 'gt_c3.png', svtype='original') |
| 249 | + io_func.imsave(gt_c4, path=out_path + 'gt_c4.png', svtype='original') |
| 250 | + io_func.imsave(gt_c5, path=out_path + 'gt_c5.png', svtype='original') |
| 251 | + |
| 252 | + io_func.imsave(cnn_fft_c1, path=out_path + 'cnn_fft_c1.png', svtype='original') |
| 253 | + io_func.imsave(cnn_fft_c2, path=out_path + 'cnn_fft_c2.png', svtype='original') |
| 254 | + io_func.imsave(cnn_fft_c3, path=out_path + 'cnn_fft_c3.png', svtype='original') |
| 255 | + io_func.imsave(cnn_fft_c4, path=out_path + 'cnn_fft_c4.png', svtype='original') |
| 256 | + io_func.imsave(cnn_fft_c5, path=out_path + 'cnn_fft_c5.png', svtype='original') |
| 257 | + io_func.imsave(cnn_fft, path=out_path + 'cnn_fft.png', svtype='original') |
| 258 | +else: |
| 259 | + sys.exit('the only available options are 3 and 5') |
| 260 | + |
0 commit comments