Skip to content

Commit 66ce152

Browse files
authored
Merge pull request #169 from DiamondLightSource/vocentering
Vocentering fixes
2 parents f497e0d + 29aff33 commit 66ce152

File tree

3 files changed

+185
-47
lines changed

3 files changed

+185
-47
lines changed

httomolibgpu/misc/morph.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,12 @@ def data_resampler(
121121
Returns:
122122
cp.ndarray: Up/Down-scaled 3D cupy array
123123
"""
124-
if data.ndim != 3:
125-
raise ValueError("only 3D data is supported")
124+
expanded = False
125+
# if 2d data is given it is extended into a 3D array along the vertical dimension
126+
if data.ndim == 2:
127+
expanded = True
128+
data = cp.expand_dims(data, 1)
129+
axis = 1
126130

127131
N, M, Z = cp.shape(data)
128132

@@ -214,4 +218,6 @@ def data_resampler(
214218
res, [newshape[0], newshape[1]], order="C"
215219
)
216220

221+
if expanded:
222+
scaled_data = cp.squeeze(scaled_data, axis=axis)
217223
return scaled_data

httomolibgpu/recon/rotation.py

Lines changed: 171 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,16 @@
3333
from cupyx.scipy.ndimage import shift, gaussian_filter
3434
from skimage.registration import phase_cross_correlation
3535
from cupyx.scipy.fftpack import get_fft_plan
36-
from cupyx.scipy.fft import rfft2
36+
from cupyx.scipy.fft import rfft2, fft2, fftshift
3737
else:
3838
load_cuda_module = Mock()
3939
shift = Mock()
4040
gaussian_filter = Mock()
4141
phase_cross_correlation = Mock()
4242
get_fft_plan = Mock()
43+
fft2 = Mock()
44+
fftshift = Mock()
45+
fft = Mock()
4346
rfft2 = Mock()
4447

4548
import math
@@ -55,23 +58,29 @@
5558
def find_center_vo(
5659
data: cp.ndarray,
5760
ind: Optional[int] = None,
58-
smin: int = -50,
59-
smax: int = 50,
61+
average_radius: Optional[int] = 0,
62+
cor_initialisation_value: Optional[float] = None,
63+
smin: int = -100,
64+
smax: int = 100,
6065
srad: float = 6.0,
6166
step: float = 0.25,
6267
ratio: float = 0.5,
6368
drop: int = 20,
6469
) -> float:
6570
"""
66-
Find rotation axis location (aka CoR) using Nghia Vo's method. See the paper
71+
Find the rotation axis location (aka the centre of rotation) using Nghia Vo's method. See the paper
6772
:cite:`vo2014reliable`.
6873
6974
Parameters
7075
----------
7176
data : cp.ndarray
72-
3D tomographic data or a 2D sinogram as a CuPy array.
77+
3D [angles, detY, detX] tomographic data or a 2D [angles, detX] sinogram as a CuPy array.
7378
ind : int, optional
74-
Index of the slice to be used to estimate the CoR.
79+
Index of the slice to be used to estimate the CoR. If None is given, then the central sinogram will be extracted from the data array with a possible averaging, see .
80+
average_radius : int, optional
81+
Averaging multiple sinograms around the ind-indexed sinogram to improve the signal-to-noise ratio. It is recommended to keep this parameter smaller than 10.
82+
cor_initialisation_value : float, optional
83+
The initial approximation for the centre of rotation. If the value is None, use the horizontal centre of the projection/sinogram image.
7584
smin : int, optional
7685
Coarse search radius. Reference to the horizontal center of
7786
the sinogram.
@@ -91,56 +100,117 @@ def find_center_vo(
91100
Returns
92101
-------
93102
float
94-
Rotation axis location.
103+
Rotation axis location with a subpixel precision.
95104
"""
105+
# if 2d sinogram is given it is extended into a 3D array along the vertical dimension
96106
if data.ndim == 2:
97107
data = cp.expand_dims(data, 1)
98108
ind = 0
99109

100-
height = data.shape[1]
110+
angles_tot, detY_size, detX_size = data.shape
101111

102112
if ind is None:
103-
ind = height // 2
104-
if height > 10:
105-
_sino = cp.mean(data[:, ind - 5 : ind + 5, :], axis=1)
113+
ind = detY_size // 2 # middle slice index
114+
# averaging the data here to improve SNR
115+
if 2 * average_radius >= detY_size:
116+
# reduce the averaging radius
117+
average_radius = ind
118+
if ind > 0:
119+
_sino = cp.mean(
120+
data[:, ind - average_radius : ind + average_radius, :], axis=1
121+
)
106122
else:
107123
_sino = data[:, ind, :]
108124
else:
109125
_sino = data[:, ind, :]
110126

127+
if cor_initialisation_value is None:
128+
cor_initialisation_value = (detX_size - 1.0) / 2.0
129+
130+
# downsampling ratios
131+
dsp_angle = 1
132+
dsp_detX = 1
133+
if detX_size > 2000:
134+
dsp_detX = 4
135+
if angles_tot > 2000:
136+
dsp_angle = 2
137+
138+
start_cor = np.int16(np.floor(1.0 * (cor_initialisation_value + smin) / dsp_detX))
139+
stop_cor = np.int16(np.ceil(1.0 * (cor_initialisation_value + smax) / dsp_detX))
140+
fine_srange = max(srad, dsp_detX)
141+
off_set = 0.5 * dsp_detX if dsp_detX > 1 else 0.0
142+
143+
# initiate denoising
111144
_sino_cs = gaussian_filter(_sino, (3, 1), mode="reflect")
112145
_sino_fs = gaussian_filter(_sino, (2, 2), mode="reflect")
113146

114-
if _sino.shape[0] * _sino.shape[1] > 4e6:
115-
# data is large, so downsample it before performing search for
116-
# centre of rotation
117-
_sino_coarse = _downsample(_sino_cs, 2, 1)
118-
init_cen = _search_coarse(_sino_coarse, smin / 4.0, smax / 4.0, ratio, drop)
119-
fine_cen = _search_fine(_sino_fs, srad, step, init_cen * 4.0, ratio, drop)
120-
else:
121-
init_cen = _search_coarse(_sino_cs, smin, smax, ratio, drop)
122-
fine_cen = _search_fine(_sino_fs, srad, step, init_cen, ratio, drop)
147+
# Downsampling by averaging along a chosen dimension
148+
if dsp_angle > 1 or dsp_detX > 1:
149+
_sino_cs = _downsample(_sino_cs, dsp_angle, dsp_detX)
123150

124-
return cp.asnumpy(fine_cen)
151+
# NOTE: the gpu implementation of _downsample kernel bellow is erroneuos (different results with each run), needs to be re-written
152+
# if dsp_angle > 1:
153+
# _sino_cs = _downsample_kernel(_sino_cs, level=dsp_angle, axis=0)
154+
# if dsp_detX > 1:
155+
# _sino_cs = _downsample_kernel(_sino_cs, level=dsp_detX, axis=1)
156+
157+
# NOTE: this is correct implementation that avoids running any CUDA kernels. The performance is suboptimal
158+
init_cen = _search_coarse(_sino_cs, start_cor, stop_cor, ratio, drop)
159+
160+
# NOTE: similar to the coarse module above, this is currently a correct function
161+
# but it is NOT using CUDA kernels written. Therefore some kernels re-writing is needed.
162+
fine_cen = _search_fine(
163+
_sino_fs, fine_srange, step, float(init_cen) * dsp_detX + off_set, ratio, drop
164+
)
165+
cen_np = np.float32(cp.asnumpy(fine_cen))
166+
if cen_np == 0.0:
167+
return cor_initialisation_value
168+
else:
169+
return cen_np
125170

126171

127172
def _search_coarse(sino, smin, smax, ratio, drop):
128173
(nrow, ncol) = sino.shape
129174
flip_sino = cp.ascontiguousarray(cp.fliplr(sino))
130175
comp_sino = cp.ascontiguousarray(cp.flipud(sino))
131-
mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
132176

177+
# # NOTE: gpu code here, half a mask created to avoid sinofram concatenitation and save memory?
178+
# mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
179+
# # NOTE: old GPU code for the sizes with half data
180+
# cen_fliplr = (ncol - 1.0) / 2.0
181+
# smin_clip_val = max(min(smin + cen_fliplr, ncol - 1), 0)
182+
# smin = smin_clip_val - cen_fliplr
183+
# smax_clip_val = max(min(smax + cen_fliplr, ncol - 1), 0)
184+
# smax = smax_clip_val - cen_fliplr
185+
# start_cor = ncol // 2 + smin
186+
# stop_cor = ncol // 2 + smax
187+
# list_cor = cp.arange(start_cor, stop_cor + 0.5, 0.5, dtype=cp.float32)
188+
# list_shift = 2.0 * (list_cor - cen_fliplr)
189+
# list_metric = cp.empty(list_shift.shape, dtype=cp.float32)
190+
191+
mask = _create_mask_numpy(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
192+
mask = cp.asarray(mask, dtype=cp.float32)
133193
cen_fliplr = (ncol - 1.0) / 2.0
134-
smin_clip_val = max(min(smin + cen_fliplr, ncol - 1), 0)
135-
smin = smin_clip_val - cen_fliplr
136-
smax_clip_val = max(min(smax + cen_fliplr, ncol - 1), 0)
137-
smax = smax_clip_val - cen_fliplr
138-
start_cor = ncol // 2 + smin
139-
stop_cor = ncol // 2 + smax
140-
list_cor = cp.arange(start_cor, stop_cor + 0.5, 0.5, dtype=cp.float32)
194+
start_cor, stop_cor = np.sort((smin, smax))
195+
start_cor = np.int16(np.clip(start_cor, 0, ncol - 1))
196+
stop_cor = np.int16(np.clip(stop_cor, 0, ncol - 1))
197+
list_cor = cp.arange(start_cor, stop_cor + 1.0, dtype=cp.float32)
141198
list_shift = 2.0 * (list_cor - cen_fliplr)
142199
list_metric = cp.empty(list_shift.shape, dtype=cp.float32)
143-
_calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, list_metric)
200+
201+
# NOTE: this gives a different result to the CPU code, also works with a half data and a half mask
202+
# _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, list_metric)
203+
204+
# This essentially repeats the CPU code... probably not optimal but correct
205+
sino_sino = cp.vstack((sino, flip_sino))
206+
for i, shift in enumerate(list_shift):
207+
_sino = sino_sino[nrow:]
208+
_sino[...] = cp.roll(flip_sino, int(shift), axis=1)
209+
if shift >= 0:
210+
_sino[:, :shift] = comp_sino[:, :shift]
211+
else:
212+
_sino[:, shift:] = comp_sino[:, shift:]
213+
list_metric[i] = cp.mean(cp.abs(fftshift(fft2(sino_sino))) * mask)
144214

145215
minpos = cp.argmin(list_metric)
146216
if minpos == 0:
@@ -158,21 +228,52 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
158228

159229
flip_sino = cp.ascontiguousarray(cp.fliplr(sino))
160230
comp_sino = cp.ascontiguousarray(cp.flipud(sino))
161-
mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
231+
mask = _create_mask_numpy(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
232+
mask = cp.asarray(mask, dtype=cp.float32)
162233

163234
cen_fliplr = (ncol - 1.0) / 2.0
164-
srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
165-
step = max(min(abs(step), srad), 0.1)
166-
init_cen = max(min(init_cen, ncol - srad - 1), srad)
167-
list_cor = init_cen + cp.arange(-srad, srad + step, step, dtype=np.float32)
235+
# NOTE: those are different to new implementation
236+
# srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
237+
# step = max(min(abs(step), srad), 0.1)
238+
srad = np.clip(np.abs(srad), 1, ncol // 10 - 1)
239+
step = np.clip(np.abs(step), 0.1, 1.1)
240+
init_cen = np.clip(init_cen, srad, ncol - srad - 1)
241+
list_cor = init_cen + cp.arange(-srad, srad + step, step, dtype=cp.float32)
168242
list_shift = 2.0 * (list_cor - cen_fliplr)
169243
list_metric = cp.empty(list_shift.shape, dtype="float32")
170244

171-
_calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
245+
for i, shift_l in enumerate(list_shift):
246+
sino_shift = shift(flip_sino, (0, shift_l), order=3, prefilter=True)
247+
if shift_l >= 0:
248+
shift_int = int(cp.ceil(shift_l))
249+
sino_shift[:, :shift_int] = comp_sino[:, :shift_int]
250+
else:
251+
shift_int = int(cp.floor(shift_l))
252+
sino_shift[:, shift_int:] = comp_sino[:, shift_int:]
253+
mat1 = cp.vstack((sino, sino_shift))
254+
list_metric[i] = cp.mean(cp.abs(fftshift(fft2(mat1))) * mask)
255+
256+
# _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
172257
cor = list_cor[cp.argmin(list_metric)]
173258
return cor
174259

175260

261+
def _create_mask_numpy(nrow, ncol, radius, drop):
262+
du = 1.0 / ncol
263+
dv = (nrow - 1.0) / (nrow * 2.0 * np.pi)
264+
cen_row = np.int16(np.ceil(nrow / 2.0) - 1)
265+
cen_col = np.int16(np.ceil(ncol / 2.0) - 1)
266+
drop = min(drop, np.int16(np.ceil(0.05 * nrow)))
267+
mask = np.zeros((nrow, ncol), dtype="float32")
268+
for i in range(nrow):
269+
pos = np.int16(np.round(((i - cen_row) * dv / radius) / du))
270+
(pos1, pos2) = np.clip(np.sort((-pos + cen_col, pos + cen_col)), 0, ncol - 1)
271+
mask[i, pos1 : pos2 + 1] = 1.0
272+
mask[cen_row - drop : cen_row + drop + 1, :] = 0.0
273+
mask[:, cen_col - 1 : cen_col + 2] = 0.0
274+
return mask
275+
276+
176277
def _create_mask(nrow, ncol, radius, drop):
177278
du = 1.0 / ncol
178279
dv = (nrow - 1.0) / (nrow * 2.0 * np.pi)
@@ -330,15 +431,46 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
330431
)
331432

332433

333-
def _downsample(sino, level, axis):
434+
def _downsample(image, dsp_fact0, dsp_fact1):
435+
"""Downsample an image by averaging.
436+
437+
Parameters
438+
----------
439+
image : 2D array.
440+
dsp_fact0 : downsampling factor along axis 0.
441+
dsp_fact1 : downsampling factor along axis 1.
442+
443+
Returns
444+
---------
445+
image_dsp : Downsampled image.
446+
"""
447+
(height, width) = image.shape
448+
dsp_fact0 = cp.clip(cp.int16(dsp_fact0), 1, height // 2)
449+
dsp_fact1 = cp.clip(cp.int16(dsp_fact1), 1, width // 2)
450+
height_dsp = height // dsp_fact0
451+
width_dsp = width // dsp_fact1
452+
if dsp_fact0 == 1 and dsp_fact1 == 1:
453+
image_dsp = image
454+
else:
455+
image_dsp = image[0 : dsp_fact0 * height_dsp, 0 : dsp_fact1 * width_dsp]
456+
image_dsp = (
457+
image_dsp.reshape(height_dsp, dsp_fact0, width_dsp, dsp_fact1)
458+
.mean(-1)
459+
.mean(1)
460+
)
461+
return image_dsp
462+
463+
464+
def _downsample_kernel(sino, level, axis):
334465
assert sino.dtype == cp.float32, "single precision floating point input required"
335466
assert sino.flags["C_CONTIGUOUS"], "list_shift must be C-contiguous"
336467

337468
dx, dz = sino.shape
338469
# Determine the new size, dim, of the downsampled dimension
339-
dim = int(sino.shape[axis] / math.pow(2, level))
470+
# dim_new_size = int(sino.shape[axis] / math.pow(2, level))
471+
dim_new_size = int(sino.shape[axis] / level)
340472
shape = [dx, dz]
341-
shape[axis] = dim
473+
shape[axis] = dim_new_size
342474
downsampled_data = cp.empty(shape, dtype="float32")
343475

344476
block_x = 8

tests/test_recon/test_rotation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@ def test_find_center_vo(data, flats, darks):
2222
data = normalize(data, flats, darks)
2323

2424
# --- testing the center of rotation on tomo_standard ---#
25-
cor = find_center_vo(data)
25+
cor = find_center_vo(
26+
data.copy(),
27+
average_radius=0,
28+
)
2629

2730
data = None #: free up GPU memory
2831
assert_allclose(cor, 79.5)
2932

30-
#: Check that we only get a float32 output
31-
assert cor.dtype == np.float32
32-
3333

3434
def test_find_center_vo_ones(ensure_clean_memory):
3535
mat = cp.ones(shape=(103, 450, 230), dtype=cp.float32)
3636
cor = find_center_vo(mat)
3737

38-
assert_allclose(cor, 59.0)
38+
assert_allclose(cor, 8)
3939
mat = None #: free up GPU memory
4040

4141

@@ -44,7 +44,7 @@ def test_find_center_vo_random(ensure_clean_memory):
4444
data_host = np.random.random_sample(size=(900, 1, 1280)).astype(np.float32) * 2.0
4545
data = cp.asarray(data_host, dtype=np.float32)
4646
cent = find_center_vo(data)
47-
assert_allclose(cent, 680.75)
47+
assert_allclose(cent, 550.25)
4848

4949

5050
def test_find_center_vo_big_data(sino3600):

0 commit comments

Comments
 (0)