3333 from cupyx .scipy .ndimage import shift , gaussian_filter
3434 from skimage .registration import phase_cross_correlation
3535 from cupyx .scipy .fftpack import get_fft_plan
36- from cupyx .scipy .fft import rfft2
36+ from cupyx .scipy .fft import rfft2 , fft2 , fftshift
3737else :
3838 load_cuda_module = Mock ()
3939 shift = Mock ()
4040 gaussian_filter = Mock ()
4141 phase_cross_correlation = Mock ()
4242 get_fft_plan = Mock ()
43+ fft2 = Mock ()
44+ fftshift = Mock ()
45+ fft = Mock ()
4346 rfft2 = Mock ()
4447
4548import math
5558def find_center_vo (
5659 data : cp .ndarray ,
5760 ind : Optional [int ] = None ,
58- smin : int = - 50 ,
59- smax : int = 50 ,
61+ average_radius : Optional [int ] = 0 ,
62+ cor_initialisation_value : Optional [float ] = None ,
63+ smin : int = - 100 ,
64+ smax : int = 100 ,
6065 srad : float = 6.0 ,
6166 step : float = 0.25 ,
6267 ratio : float = 0.5 ,
6368 drop : int = 20 ,
6469) -> float :
6570 """
66- Find rotation axis location (aka CoR ) using Nghia Vo's method. See the paper
71+ Find the rotation axis location (aka the centre of rotation ) using Nghia Vo's method. See the paper
6772 :cite:`vo2014reliable`.
6873
6974 Parameters
7075 ----------
7176 data : cp.ndarray
72- 3D tomographic data or a 2D sinogram as a CuPy array.
77+ 3D [angles, detY, detX] tomographic data or a 2D [angles, detX] sinogram as a CuPy array.
7378 ind : int, optional
74- Index of the slice to be used to estimate the CoR.
79+ Index of the slice to be used to estimate the CoR. If None is given, then the central sinogram will be extracted from the data array with a possible averaging, see .
80+ average_radius : int, optional
81+ Averaging multiple sinograms around the ind-indexed sinogram to improve the signal-to-noise ratio. It is recommended to keep this parameter smaller than 10.
82+ cor_initialisation_value : float, optional
83+ The initial approximation for the centre of rotation. If the value is None, use the horizontal centre of the projection/sinogram image.
7584 smin : int, optional
7685 Coarse search radius. Reference to the horizontal center of
7786 the sinogram.
@@ -91,56 +100,117 @@ def find_center_vo(
91100 Returns
92101 -------
93102 float
94- Rotation axis location.
103+ Rotation axis location with a subpixel precision .
95104 """
105+ # if 2d sinogram is given it is extended into a 3D array along the vertical dimension
96106 if data .ndim == 2 :
97107 data = cp .expand_dims (data , 1 )
98108 ind = 0
99109
100- height = data .shape [ 1 ]
110+ angles_tot , detY_size , detX_size = data .shape
101111
102112 if ind is None :
103- ind = height // 2
104- if height > 10 :
105- _sino = cp .mean (data [:, ind - 5 : ind + 5 , :], axis = 1 )
113+ ind = detY_size // 2 # middle slice index
114+ # averaging the data here to improve SNR
115+ if 2 * average_radius >= detY_size :
116+ # reduce the averaging radius
117+ average_radius = ind
118+ if ind > 0 :
119+ _sino = cp .mean (
120+ data [:, ind - average_radius : ind + average_radius , :], axis = 1
121+ )
106122 else :
107123 _sino = data [:, ind , :]
108124 else :
109125 _sino = data [:, ind , :]
110126
127+ if cor_initialisation_value is None :
128+ cor_initialisation_value = (detX_size - 1.0 ) / 2.0
129+
130+ # downsampling ratios
131+ dsp_angle = 1
132+ dsp_detX = 1
133+ if detX_size > 2000 :
134+ dsp_detX = 4
135+ if angles_tot > 2000 :
136+ dsp_angle = 2
137+
138+ start_cor = np .int16 (np .floor (1.0 * (cor_initialisation_value + smin ) / dsp_detX ))
139+ stop_cor = np .int16 (np .ceil (1.0 * (cor_initialisation_value + smax ) / dsp_detX ))
140+ fine_srange = max (srad , dsp_detX )
141+ off_set = 0.5 * dsp_detX if dsp_detX > 1 else 0.0
142+
143+ # initiate denoising
111144 _sino_cs = gaussian_filter (_sino , (3 , 1 ), mode = "reflect" )
112145 _sino_fs = gaussian_filter (_sino , (2 , 2 ), mode = "reflect" )
113146
114- if _sino .shape [0 ] * _sino .shape [1 ] > 4e6 :
115- # data is large, so downsample it before performing search for
116- # centre of rotation
117- _sino_coarse = _downsample (_sino_cs , 2 , 1 )
118- init_cen = _search_coarse (_sino_coarse , smin / 4.0 , smax / 4.0 , ratio , drop )
119- fine_cen = _search_fine (_sino_fs , srad , step , init_cen * 4.0 , ratio , drop )
120- else :
121- init_cen = _search_coarse (_sino_cs , smin , smax , ratio , drop )
122- fine_cen = _search_fine (_sino_fs , srad , step , init_cen , ratio , drop )
147+ # Downsampling by averaging along a chosen dimension
148+ if dsp_angle > 1 or dsp_detX > 1 :
149+ _sino_cs = _downsample (_sino_cs , dsp_angle , dsp_detX )
123150
124- return cp .asnumpy (fine_cen )
151+ # NOTE: the gpu implementation of _downsample kernel bellow is erroneuos (different results with each run), needs to be re-written
152+ # if dsp_angle > 1:
153+ # _sino_cs = _downsample_kernel(_sino_cs, level=dsp_angle, axis=0)
154+ # if dsp_detX > 1:
155+ # _sino_cs = _downsample_kernel(_sino_cs, level=dsp_detX, axis=1)
156+
157+ # NOTE: this is correct implementation that avoids running any CUDA kernels. The performance is suboptimal
158+ init_cen = _search_coarse (_sino_cs , start_cor , stop_cor , ratio , drop )
159+
160+ # NOTE: similar to the coarse module above, this is currently a correct function
161+ # but it is NOT using CUDA kernels written. Therefore some kernels re-writing is needed.
162+ fine_cen = _search_fine (
163+ _sino_fs , fine_srange , step , float (init_cen ) * dsp_detX + off_set , ratio , drop
164+ )
165+ cen_np = np .float32 (cp .asnumpy (fine_cen ))
166+ if cen_np == 0.0 :
167+ return cor_initialisation_value
168+ else :
169+ return cen_np
125170
126171
127172def _search_coarse (sino , smin , smax , ratio , drop ):
128173 (nrow , ncol ) = sino .shape
129174 flip_sino = cp .ascontiguousarray (cp .fliplr (sino ))
130175 comp_sino = cp .ascontiguousarray (cp .flipud (sino ))
131- mask = _create_mask (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
132176
177+ # # NOTE: gpu code here, half a mask created to avoid sinofram concatenitation and save memory?
178+ # mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
179+ # # NOTE: old GPU code for the sizes with half data
180+ # cen_fliplr = (ncol - 1.0) / 2.0
181+ # smin_clip_val = max(min(smin + cen_fliplr, ncol - 1), 0)
182+ # smin = smin_clip_val - cen_fliplr
183+ # smax_clip_val = max(min(smax + cen_fliplr, ncol - 1), 0)
184+ # smax = smax_clip_val - cen_fliplr
185+ # start_cor = ncol // 2 + smin
186+ # stop_cor = ncol // 2 + smax
187+ # list_cor = cp.arange(start_cor, stop_cor + 0.5, 0.5, dtype=cp.float32)
188+ # list_shift = 2.0 * (list_cor - cen_fliplr)
189+ # list_metric = cp.empty(list_shift.shape, dtype=cp.float32)
190+
191+ mask = _create_mask_numpy (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
192+ mask = cp .asarray (mask , dtype = cp .float32 )
133193 cen_fliplr = (ncol - 1.0 ) / 2.0
134- smin_clip_val = max (min (smin + cen_fliplr , ncol - 1 ), 0 )
135- smin = smin_clip_val - cen_fliplr
136- smax_clip_val = max (min (smax + cen_fliplr , ncol - 1 ), 0 )
137- smax = smax_clip_val - cen_fliplr
138- start_cor = ncol // 2 + smin
139- stop_cor = ncol // 2 + smax
140- list_cor = cp .arange (start_cor , stop_cor + 0.5 , 0.5 , dtype = cp .float32 )
194+ start_cor , stop_cor = np .sort ((smin , smax ))
195+ start_cor = np .int16 (np .clip (start_cor , 0 , ncol - 1 ))
196+ stop_cor = np .int16 (np .clip (stop_cor , 0 , ncol - 1 ))
197+ list_cor = cp .arange (start_cor , stop_cor + 1.0 , dtype = cp .float32 )
141198 list_shift = 2.0 * (list_cor - cen_fliplr )
142199 list_metric = cp .empty (list_shift .shape , dtype = cp .float32 )
143- _calculate_metric (list_shift , sino , flip_sino , comp_sino , mask , list_metric )
200+
201+ # NOTE: this gives a different result to the CPU code, also works with a half data and a half mask
202+ # _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, list_metric)
203+
204+ # This essentially repeats the CPU code... probably not optimal but correct
205+ sino_sino = cp .vstack ((sino , flip_sino ))
206+ for i , shift in enumerate (list_shift ):
207+ _sino = sino_sino [nrow :]
208+ _sino [...] = cp .roll (flip_sino , int (shift ), axis = 1 )
209+ if shift >= 0 :
210+ _sino [:, :shift ] = comp_sino [:, :shift ]
211+ else :
212+ _sino [:, shift :] = comp_sino [:, shift :]
213+ list_metric [i ] = cp .mean (cp .abs (fftshift (fft2 (sino_sino ))) * mask )
144214
145215 minpos = cp .argmin (list_metric )
146216 if minpos == 0 :
@@ -158,21 +228,52 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
158228
159229 flip_sino = cp .ascontiguousarray (cp .fliplr (sino ))
160230 comp_sino = cp .ascontiguousarray (cp .flipud (sino ))
161- mask = _create_mask (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
231+ mask = _create_mask_numpy (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
232+ mask = cp .asarray (mask , dtype = cp .float32 )
162233
163234 cen_fliplr = (ncol - 1.0 ) / 2.0
164- srad = max (min (abs (float (srad )), ncol / 4.0 ), 1.0 )
165- step = max (min (abs (step ), srad ), 0.1 )
166- init_cen = max (min (init_cen , ncol - srad - 1 ), srad )
167- list_cor = init_cen + cp .arange (- srad , srad + step , step , dtype = np .float32 )
235+ # NOTE: those are different to new implementation
236+ # srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
237+ # step = max(min(abs(step), srad), 0.1)
238+ srad = np .clip (np .abs (srad ), 1 , ncol // 10 - 1 )
239+ step = np .clip (np .abs (step ), 0.1 , 1.1 )
240+ init_cen = np .clip (init_cen , srad , ncol - srad - 1 )
241+ list_cor = init_cen + cp .arange (- srad , srad + step , step , dtype = cp .float32 )
168242 list_shift = 2.0 * (list_cor - cen_fliplr )
169243 list_metric = cp .empty (list_shift .shape , dtype = "float32" )
170244
171- _calculate_metric (list_shift , sino , flip_sino , comp_sino , mask , out = list_metric )
245+ for i , shift_l in enumerate (list_shift ):
246+ sino_shift = shift (flip_sino , (0 , shift_l ), order = 3 , prefilter = True )
247+ if shift_l >= 0 :
248+ shift_int = int (cp .ceil (shift_l ))
249+ sino_shift [:, :shift_int ] = comp_sino [:, :shift_int ]
250+ else :
251+ shift_int = int (cp .floor (shift_l ))
252+ sino_shift [:, shift_int :] = comp_sino [:, shift_int :]
253+ mat1 = cp .vstack ((sino , sino_shift ))
254+ list_metric [i ] = cp .mean (cp .abs (fftshift (fft2 (mat1 ))) * mask )
255+
256+ # _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
172257 cor = list_cor [cp .argmin (list_metric )]
173258 return cor
174259
175260
261+ def _create_mask_numpy (nrow , ncol , radius , drop ):
262+ du = 1.0 / ncol
263+ dv = (nrow - 1.0 ) / (nrow * 2.0 * np .pi )
264+ cen_row = np .int16 (np .ceil (nrow / 2.0 ) - 1 )
265+ cen_col = np .int16 (np .ceil (ncol / 2.0 ) - 1 )
266+ drop = min (drop , np .int16 (np .ceil (0.05 * nrow )))
267+ mask = np .zeros ((nrow , ncol ), dtype = "float32" )
268+ for i in range (nrow ):
269+ pos = np .int16 (np .round (((i - cen_row ) * dv / radius ) / du ))
270+ (pos1 , pos2 ) = np .clip (np .sort ((- pos + cen_col , pos + cen_col )), 0 , ncol - 1 )
271+ mask [i , pos1 : pos2 + 1 ] = 1.0
272+ mask [cen_row - drop : cen_row + drop + 1 , :] = 0.0
273+ mask [:, cen_col - 1 : cen_col + 2 ] = 0.0
274+ return mask
275+
276+
176277def _create_mask (nrow , ncol , radius , drop ):
177278 du = 1.0 / ncol
178279 dv = (nrow - 1.0 ) / (nrow * 2.0 * np .pi )
@@ -330,15 +431,46 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
330431 )
331432
332433
333- def _downsample (sino , level , axis ):
434+ def _downsample (image , dsp_fact0 , dsp_fact1 ):
435+ """Downsample an image by averaging.
436+
437+ Parameters
438+ ----------
439+ image : 2D array.
440+ dsp_fact0 : downsampling factor along axis 0.
441+ dsp_fact1 : downsampling factor along axis 1.
442+
443+ Returns
444+ ---------
445+ image_dsp : Downsampled image.
446+ """
447+ (height , width ) = image .shape
448+ dsp_fact0 = cp .clip (cp .int16 (dsp_fact0 ), 1 , height // 2 )
449+ dsp_fact1 = cp .clip (cp .int16 (dsp_fact1 ), 1 , width // 2 )
450+ height_dsp = height // dsp_fact0
451+ width_dsp = width // dsp_fact1
452+ if dsp_fact0 == 1 and dsp_fact1 == 1 :
453+ image_dsp = image
454+ else :
455+ image_dsp = image [0 : dsp_fact0 * height_dsp , 0 : dsp_fact1 * width_dsp ]
456+ image_dsp = (
457+ image_dsp .reshape (height_dsp , dsp_fact0 , width_dsp , dsp_fact1 )
458+ .mean (- 1 )
459+ .mean (1 )
460+ )
461+ return image_dsp
462+
463+
464+ def _downsample_kernel (sino , level , axis ):
334465 assert sino .dtype == cp .float32 , "single precision floating point input required"
335466 assert sino .flags ["C_CONTIGUOUS" ], "list_shift must be C-contiguous"
336467
337468 dx , dz = sino .shape
338469 # Determine the new size, dim, of the downsampled dimension
339- dim = int (sino .shape [axis ] / math .pow (2 , level ))
470+ # dim_new_size = int(sino.shape[axis] / math.pow(2, level))
471+ dim_new_size = int (sino .shape [axis ] / level )
340472 shape = [dx , dz ]
341- shape [axis ] = dim
473+ shape [axis ] = dim_new_size
342474 downsampled_data = cp .empty (shape , dtype = "float32" )
343475
344476 block_x = 8
0 commit comments