Skip to content

Commit 324cf22

Browse files
committed
Better memory estimation of FFT plan
1 parent 9fea21a commit 324cf22

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
if cupy_run:
3333
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
3434
from cupyx.scipy.fft import fft2, ifft2, fftshift
35+
from cupyx.scipy.fftpack import get_fft_plan
3536
from httomolibgpu.cuda_kernels import load_cuda_module
3637
else:
3738
median_filter = Mock()
@@ -364,8 +365,8 @@ def _conv_transpose2d(
364365
wi = (wo - 1) * stride[1] + wk
365366
out_shape = [b, ci, hi, wi]
366367
if mem_stack:
367-
# The trouble here is that we allocate more than the returned size
368-
mem_stack.malloc((np.prod(out_shape) + w.size) * np.float32().itemsize)
368+
mem_stack.malloc(np.prod(out_shape) * np.float32().itemsize)
369+
mem_stack.malloc(w.size * np.float32().itemsize)
369370
if pad != 0:
370371
new_out_shape = [
371372
out_shape[0],
@@ -374,8 +375,9 @@ def _conv_transpose2d(
374375
out_shape[3] - 2 * pad[1],
375376
]
376377
mem_stack.malloc(np.prod(new_out_shape) * np.float32().itemsize)
377-
mem_stack.free((np.prod(out_shape) + w.size) * np.float32().itemsize)
378+
mem_stack.free(np.prod(out_shape) * np.float32().itemsize)
378379
out_shape = new_out_shape
380+
mem_stack.free(w.size * np.float32().itemsize)
379381
return out_shape
380382

381383
out = cp.zeros(out_shape, dtype="float32")
@@ -682,13 +684,17 @@ def remove_stripe_fw(
682684
mem_stack.malloc(fcV_bytes)
683685

684686
# For the FFT
687+
mem_stack.malloc(2 * np.prod(fcV_shape) * np.float32().itemsize)
685688
mem_stack.malloc(2 * fcV_bytes)
686-
# This is "leaked" by the FFT
687-
if fcV_shape[1] > 150:
688-
fcV_fft_bytes = fcV_bytes
689-
else:
690-
fcV_fft_bytes = fcV_shape[0] * fcV_shape[2] * np.complex64().itemsize
691-
mem_stack.malloc(fcV_fft_bytes)
689+
690+
fft_dummy = cp.empty(fcV_shape, dtype='float32')
691+
fft_plan = get_fft_plan(fft_dummy)
692+
fft_plan_size = fft_plan.work_area.mem.size
693+
del fft_dummy
694+
del fft_plan
695+
mem_stack.malloc(fft_plan_size)
696+
mem_stack.free(2 * np.prod(fcV_shape) * np.float32().itemsize)
697+
mem_stack.free(fft_plan_size)
692698
mem_stack.free(2 * fcV_bytes)
693699

694700
# The rest of the iteration doesn't contribute to the peak
@@ -699,26 +705,37 @@ def remove_stripe_fw(
699705
new_sli_shape = ifm.apply((new_sli_shape, cc[k]), mem_stack)
700706
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
701707
sli_shape = new_sli_shape
708+
709+
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
702710
for c in cc:
703711
mem_stack.free(np.prod(c) * np.float32().itemsize)
704-
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
705712
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
706-
return int(mem_stack.highwater * 1.1)
713+
return mem_stack.highwater
707714

708715
sli = cp.zeros(sli_shape, dtype="float32")
709716
sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2] = data.swapaxes(0, 1)
710717
for k in range(level):
711718
sli, c = xfm.apply(sli)
712719
cc.append(c)
713720
# FFT
714-
fcV = cp.fft.fft(cc[k][:, 0, 1], axis=1)
721+
fft_in = cp.ascontiguousarray(cc[k][:, 0, 1])
722+
fft_plan = get_fft_plan(fft_in, axes=1)
723+
with fft_plan:
724+
fcV = cp.fft.fft(fft_in, axis=1)
725+
del fft_plan
726+
del fft_in
715727
_, my, mx = fcV.shape
716728
# Damping of ring artifact information.
717729
y_hat = np.fft.ifftshift((np.arange(-my, my, 2) + 1) / 2)
718730
damp = -np.expm1(-(y_hat**2) / (2 * sigma**2))
719731
fcV *= cp.tile(damp, (mx, 1)).swapaxes(0, 1)
720732
# Inverse FFT.
721-
cc[k][:, 0, 1] = cp.fft.ifft(fcV, my, axis=1).real
733+
ifft_in = cp.ascontiguousarray(fcV)
734+
ifft_plan = get_fft_plan(ifft_in, axes=1)
735+
with ifft_plan:
736+
cc[k][:, 0, 1] = cp.fft.ifft(ifft_in, my, axis=1).real
737+
del ifft_plan
738+
del ifft_in
722739

723740
# Wavelet reconstruction.
724741
for k in range(level)[::-1]:

tests/test_prep/test_stripe.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,6 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
101101
estimated_mem_peak = remove_stripe_fw(
102102
data.shape, level=level, wname=wname, calc_peak_gpu_mem=True
103103
)
104-
assert hook.max_mem == 0
105104

106105
assert actual_mem_peak * 0.99 <= estimated_mem_peak
107106
assert estimated_mem_peak <= actual_mem_peak * 1.3
@@ -121,7 +120,6 @@ def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory
121120
estimated_mem_peak = remove_stripe_fw(
122121
data_shape, wname=wname, level=level, calc_peak_gpu_mem=True
123122
)
124-
assert hook.max_mem == 0
125123
av_mem = cp.cuda.Device().mem_info[0]
126124
if av_mem < estimated_mem_peak:
127125
pytest.skip("Not enough GPU memory to run this test")

0 commit comments

Comments
 (0)