3232if cupy_run :
3333 from cupyx .scipy .ndimage import median_filter , binary_dilation , uniform_filter1d
3434 from cupyx .scipy .fft import fft2 , ifft2 , fftshift
35+ from cupyx .scipy .fftpack import get_fft_plan
3536 from httomolibgpu .cuda_kernels import load_cuda_module
3637else :
3738 median_filter = Mock ()
@@ -364,8 +365,8 @@ def _conv_transpose2d(
364365 wi = (wo - 1 ) * stride [1 ] + wk
365366 out_shape = [b , ci , hi , wi ]
366367 if mem_stack :
367- # The trouble here is that we allocate more than the returned size
368- mem_stack .malloc (( np . prod ( out_shape ) + w .size ) * np .float32 ().itemsize )
368+ mem_stack . malloc ( np . prod ( out_shape ) * np . float32 (). itemsize )
369+ mem_stack .malloc (w .size * np .float32 ().itemsize )
369370 if pad != 0 :
370371 new_out_shape = [
371372 out_shape [0 ],
@@ -374,8 +375,9 @@ def _conv_transpose2d(
374375 out_shape [3 ] - 2 * pad [1 ],
375376 ]
376377 mem_stack .malloc (np .prod (new_out_shape ) * np .float32 ().itemsize )
377- mem_stack .free (( np .prod (out_shape ) + w . size ) * np .float32 ().itemsize )
378+ mem_stack .free (np .prod (out_shape ) * np .float32 ().itemsize )
378379 out_shape = new_out_shape
380+ mem_stack .free (w .size * np .float32 ().itemsize )
379381 return out_shape
380382
381383 out = cp .zeros (out_shape , dtype = "float32" )
@@ -682,13 +684,17 @@ def remove_stripe_fw(
682684 mem_stack .malloc (fcV_bytes )
683685
684686 # For the FFT
687+ mem_stack .malloc (2 * np .prod (fcV_shape ) * np .float32 ().itemsize )
685688 mem_stack .malloc (2 * fcV_bytes )
686- # This is "leaked" by the FFT
687- if fcV_shape [1 ] > 150 :
688- fcV_fft_bytes = fcV_bytes
689- else :
690- fcV_fft_bytes = fcV_shape [0 ] * fcV_shape [2 ] * np .complex64 ().itemsize
691- mem_stack .malloc (fcV_fft_bytes )
689+
690+ fft_dummy = cp .empty (fcV_shape , dtype = 'float32' )
691+ fft_plan = get_fft_plan (fft_dummy )
692+ fft_plan_size = fft_plan .work_area .mem .size
693+ del fft_dummy
694+ del fft_plan
695+ mem_stack .malloc (fft_plan_size )
696+ mem_stack .free (2 * np .prod (fcV_shape ) * np .float32 ().itemsize )
697+ mem_stack .free (fft_plan_size )
692698 mem_stack .free (2 * fcV_bytes )
693699
694700 # The rest of the iteration doesn't contribute to the peak
@@ -699,26 +705,37 @@ def remove_stripe_fw(
699705 new_sli_shape = ifm .apply ((new_sli_shape , cc [k ]), mem_stack )
700706 mem_stack .free (np .prod (sli_shape ) * np .float32 ().itemsize )
701707 sli_shape = new_sli_shape
708+
709+ mem_stack .malloc (np .prod (data ) * np .float32 ().itemsize )
702710 for c in cc :
703711 mem_stack .free (np .prod (c ) * np .float32 ().itemsize )
704- mem_stack .malloc (np .prod (data ) * np .float32 ().itemsize )
705712 mem_stack .free (np .prod (sli_shape ) * np .float32 ().itemsize )
706- return int ( mem_stack .highwater * 1.1 )
713+ return mem_stack .highwater
707714
708715 sli = cp .zeros (sli_shape , dtype = "float32" )
709716 sli [:, 0 , (nproj_pad - nproj ) // 2 : (nproj_pad + nproj ) // 2 ] = data .swapaxes (0 , 1 )
710717 for k in range (level ):
711718 sli , c = xfm .apply (sli )
712719 cc .append (c )
713720 # FFT
714- fcV = cp .fft .fft (cc [k ][:, 0 , 1 ], axis = 1 )
721+ fft_in = cp .ascontiguousarray (cc [k ][:, 0 , 1 ])
722+ fft_plan = get_fft_plan (fft_in , axes = 1 )
723+ with fft_plan :
724+ fcV = cp .fft .fft (fft_in , axis = 1 )
725+ del fft_plan
726+ del fft_in
715727 _ , my , mx = fcV .shape
716728 # Damping of ring artifact information.
717729 y_hat = np .fft .ifftshift ((np .arange (- my , my , 2 ) + 1 ) / 2 )
718730 damp = - np .expm1 (- (y_hat ** 2 ) / (2 * sigma ** 2 ))
719731 fcV *= cp .tile (damp , (mx , 1 )).swapaxes (0 , 1 )
720732 # Inverse FFT.
721- cc [k ][:, 0 , 1 ] = cp .fft .ifft (fcV , my , axis = 1 ).real
733+ ifft_in = cp .ascontiguousarray (fcV )
734+ ifft_plan = get_fft_plan (ifft_in , axes = 1 )
735+ with ifft_plan :
736+ cc [k ][:, 0 , 1 ] = cp .fft .ifft (ifft_in , my , axis = 1 ).real
737+ del ifft_plan
738+ del ifft_in
722739
723740 # Wavelet reconstruction.
724741 for k in range (level )[::- 1 ]:
0 commit comments