@@ -172,168 +172,70 @@ def _calc_memory_bytes_LPRec(
172172 n = n - 1 # dealing with the odd horizontal detector size
173173 odd_horiz = True
174174
175- def debug_print (line_number : int , var_name : str , size_in_bytes : int ) -> str :
176- print (f"{ line_number } - { var_name } : { size_in_bytes } B / { size_in_bytes / 1024 } KB / { size_in_bytes / 1024 ** 2 } MB / { size_in_bytes * 16 / 1024 ** 2 } MB" )
177-
178175 eps = 1e-4 # accuracy of usfft
179176 mu = - np .log (eps ) / (2 * n * n )
180- m = int (
181- np .ceil (
182- 2 * n * 1 / np .pi * np .sqrt (- mu * np .log (eps ) + (mu * n ) * (mu * n ) / 4 )
183- )
184- )
177+ m = int (np .ceil (2 * n * 1 / np .pi * np .sqrt (- mu * np .log (eps ) + (mu * n ) * (mu * n ) / 4 )))
185178
186179 center_size = 6144
187180 center_size = min (center_size , n * 2 + m * 2 )
188- print (f"m: { m } " )
189- print (f"center_size: { center_size } " )
190181
191182 oversampling_level = 2 # at least 2 or larger required
192183 ne = oversampling_level * n
193184 padding_m = ne // 2 - n // 2
194- print (f"padding_m: { padding_m } " )
195185
196186 output_dims = __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
197187 if odd_horiz :
198188 output_dims = tuple (x + 1 for x in output_dims )
199- print (f"output_dims: { output_dims } " )
200189
201190 in_slice_size = np .prod (non_slice_dims_shape ) * dtype .itemsize
202- debug_print (0 , "in_slice_size" , in_slice_size )
203191 padded_in_slice_size = np .prod (non_slice_dims_shape ) * np .float32 ().itemsize
204- debug_print (246 , "padded_in_slice_size" , padded_in_slice_size )
205192 theta_size = angles_tot * np .float32 ().itemsize
206- debug_print (262 , "theta_size" , theta_size )
207193 recon_output_size = (n + 1 ) * (n + 1 ) * np .float32 ().itemsize if odd_horiz else n * n * np .float32 ().itemsize # 264
208- debug_print (281 , "recon_output_size" , recon_output_size )
209194 linspace_size = n * np .float32 ().itemsize
210- debug_print (285 , "linspace_size" , linspace_size )
211195 meshgrid_size = 2 * n * n * np .float32 ().itemsize
212- debug_print (286 , "meshgrid_size" , meshgrid_size )
213196 phi_size = 6 * n * n * np .float32 ().itemsize
214- debug_print (287 , "phi_size" , phi_size )
215197 angle_range_size = center_size * center_size * 3 * np .int32 ().itemsize
216- debug_print (293 , "angle_range_size" , angle_range_size )
217198 c1dfftshift_size = n * np .int8 ().itemsize
218- debug_print (296 , "c1dfftshift_size" , c1dfftshift_size )
219199 c2dfftshift_slice_size = 4 * n * n * np .int8 ().itemsize
220- debug_print (299 , "c2dfftshift_slice_size" , c2dfftshift_slice_size )
221200 filter_size = (n // 2 + 1 ) * np .float32 ().itemsize
222- debug_print (309 , "filter_size" , filter_size )
223201 rfftfreq_size = filter_size
224- debug_print (312 , "rfftfreq_size" , rfftfreq_size )
225202 scaled_filter_size = filter_size
226- debug_print (313 , "scaled_filter_size" , scaled_filter_size )
227203 tmp_p_input_slice = np .prod (non_slice_dims_shape ) * np .float32 ().itemsize
228- debug_print (316 , "tmp_p_input_slice" , tmp_p_input_slice )
229204 padded_tmp_p_input_slice = angles_tot * (n + padding_m * 2 ) * dtype .itemsize
230- debug_print (326 , "padded_tmp_p_input_slice" , padded_tmp_p_input_slice )
231205 rfft_result_size = padded_tmp_p_input_slice
232- debug_print (327 , "rfft_result_size" , rfft_result_size )
233206 filtered_rfft_result_size = rfft_result_size
234- debug_print (327 , "filtered_rfft_result_size" , filtered_rfft_result_size )
235207 rfft_plan_slice_size = cufft_estimate_1d (nx = (n + padding_m * 2 ),fft_type = CufftType .CUFFT_R2C ,batch = angles_tot * SLICES ) / SLICES
236- debug_print (327 , "rfft_plan_slice_size" , rfft_plan_slice_size )
237208 irfft_result_size = filtered_rfft_result_size
238- debug_print (327 , "irfft_result_size" , irfft_result_size )
239209 irfft_scratch_memory_size = filtered_rfft_result_size
240- debug_print (327 , "irfft_scratch_memory_size" , irfft_scratch_memory_size )
241210 irfft_plan_slice_size = cufft_estimate_1d (nx = (n + padding_m * 2 ),fft_type = CufftType .CUFFT_C2R ,batch = angles_tot * SLICES ) / SLICES
242- debug_print (327 , "irfft_plan_slice_size" , irfft_plan_slice_size )
243211 conversion_to_complex_size = np .prod (non_slice_dims_shape ) * np .complex64 ().itemsize / 2
244- debug_print (333 , "conversion_to_complex_size" , conversion_to_complex_size )
245212 datac_size = np .prod (non_slice_dims_shape ) * np .complex64 ().itemsize / 2
246- debug_print (333 , "datac_size" , datac_size )
247213 fde_size = (2 * m + 2 * n ) * (2 * m + 2 * n ) * np .complex64 ().itemsize / 2
248- debug_print (341 , "fde_size" , fde_size )
249214 shifted_datac_size = datac_size
250- debug_print (344 , "shifted_datac_size" , shifted_datac_size )
251215 fft_result_size = datac_size
252- debug_print (344 , "fft_result_size" , fft_result_size )
253216 backshifted_datac_size = datac_size
254- debug_print (344 , "backshifted_datac_size" , backshifted_datac_size )
255217 scaled_backshifted_datac_size = datac_size
256- debug_print (344 , "scaled_backshifted_datac_size" , scaled_backshifted_datac_size )
257218 fft_plan_slice_size = cufft_estimate_1d (nx = n ,fft_type = CufftType .CUFFT_C2C ,batch = angles_tot * SLICES ) / SLICES
258- debug_print (344 , "fft_plan_slice_size" , fft_plan_slice_size )
259219 fde_view_size = 4 * n * n * np .complex64 ().itemsize / 2
260220 shifted_fde_view_size = fde_view_size
261- debug_print (474 , "shifted_fde_view_size" , shifted_fde_view_size )
262221 ifft2_scratch_memory_size = fde_view_size
263- debug_print (474 , "ifft2_scratch_memory_size" , ifft2_scratch_memory_size )
264222 ifft2_plan_slice_size = cufft_estimate_2d (nx = (2 * n ),ny = (2 * n ),fft_type = CufftType .CUFFT_C2C ) / 2
265- debug_print (474 , "ifft2_plan_slice_size" , ifft2_plan_slice_size )
266223 fde2_size = n * n * np .complex64 ().itemsize / 2
267- debug_print (479 , "fde2_size" , fde2_size )
268224 concatenate_size = fde2_size
269- debug_print (485 , "concatenate_size" , concatenate_size )
270225 circular_mask_size = np .prod (output_dims ) / 2 * np .int64 ().itemsize * 4
271- debug_print (496 , "circular_mask_size" , circular_mask_size )
272226
273227 after_recon_swapaxis_slice = np .prod (non_slice_dims_shape ) * np .float32 ().itemsize
274- debug_print (0 , "after_recon_swapaxis_slice" , after_recon_swapaxis_slice )
275-
276- scope_sums = [
277- in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + padded_tmp_p_input_slice + rfft_result_size + filtered_rfft_result_size + irfft_result_size + irfft_scratch_memory_size
278- , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + datac_size + conversion_to_complex_size
279- , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + fde_size + datac_size + shifted_datac_size + fft_result_size + backshifted_datac_size + scaled_backshifted_datac_size
280- , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + shifted_fde_view_size + ifft2_scratch_memory_size
281- , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + fde2_size + concatenate_size
282- , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + after_recon_swapaxis_slice
283- ]
284-
285- print (f"all per slice memory estimation: {
286- in_slice_size
287- + padded_in_slice_size
288- + recon_output_size
289- + rfft_plan_slice_size
290- + irfft_plan_slice_size
291- + tmp_p_input_slice
292- + padded_tmp_p_input_slice
293- + rfft_result_size
294- + filtered_rfft_result_size
295- + irfft_result_size
296- + datac_size
297- + conversion_to_complex_size
298- + fft_plan_slice_size + fde_size + shifted_datac_size + fft_result_size + backshifted_datac_size + scaled_backshifted_datac_size
299- + ifft2_plan_slice_size + shifted_fde_view_size
300- + fde2_size + concatenate_size
301- + after_recon_swapaxis_slice
302- } " )
303-
304- print (f"all fixed memory estimation: {
305- theta_size + phi_size + linspace_size + meshgrid_size
306- + angle_range_size + c1dfftshift_size + c2dfftshift_slice_size + filter_size + rfftfreq_size + scaled_filter_size + circular_mask_size
307- } " )
308-
309- print (f"all memory estimation assuming 15 slices: {
310- in_slice_size * 15
311- + (
312- padded_in_slice_size
313- + recon_output_size
314- + rfft_plan_slice_size
315- + irfft_plan_slice_size
316- + tmp_p_input_slice
317- + padded_tmp_p_input_slice
318- + rfft_result_size
319- + filtered_rfft_result_size
320- + irfft_result_size
321- + datac_size
322- + conversion_to_complex_size
323- + fft_plan_slice_size + fde_size + shifted_datac_size + fft_result_size + backshifted_datac_size + scaled_backshifted_datac_size
324- + ifft2_plan_slice_size + shifted_fde_view_size
325- + fde2_size + concatenate_size
326- ) * 16
327- + after_recon_swapaxis_slice * 15
328- + theta_size + phi_size + linspace_size + meshgrid_size
329- + angle_range_size + c1dfftshift_size + c2dfftshift_slice_size + filter_size + rfftfreq_size + scaled_filter_size + circular_mask_size
330- } " )
331-
332- print (f"scoped_sums: { scope_sums } " )
333- tot_memory_bytes_peak = max (scope_sums )
334- tot_memory_peak_index = scope_sums .index (tot_memory_bytes_peak )
335- print (f"tot_memory_peak_index: { tot_memory_peak_index } " )
336- tot_memory_bytes = int (tot_memory_bytes_peak )
228+
229+ tot_memory_bytes = int (
230+ max (
231+ in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + padded_tmp_p_input_slice + rfft_result_size + filtered_rfft_result_size + irfft_result_size + irfft_scratch_memory_size
232+ , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + tmp_p_input_slice + datac_size + conversion_to_complex_size
233+ , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + fde_size + datac_size + shifted_datac_size + fft_result_size + backshifted_datac_size + scaled_backshifted_datac_size
234+ , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + shifted_fde_view_size + ifft2_scratch_memory_size
235+ , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + fde2_size + concatenate_size
236+ , in_slice_size + padded_in_slice_size + recon_output_size + rfft_plan_slice_size + irfft_plan_slice_size + fft_plan_slice_size + ifft2_plan_slice_size + after_recon_swapaxis_slice
237+ )
238+ )
337239
338240 fixed_amount = int (
339241 max (
@@ -343,9 +245,6 @@ def debug_print(line_number: int, var_name: str, size_in_bytes: int) -> str:
343245 )
344246 )
345247
346- print (f"tot_memory_bytes: { tot_memory_bytes } " )
347- print (f"fixed_amount: { fixed_amount } " )
348-
349248 return (tot_memory_bytes , fixed_amount )
350249
351250
0 commit comments