3030 "_calc_memory_bytes_LPRec3d_tomobar" ,
3131 "_calc_memory_bytes_SIRT3d_tomobar" ,
3232 "_calc_memory_bytes_CGLS3d_tomobar" ,
33+ "_calc_memory_bytes_FISTA3d_tomobar" ,
3334 "_calc_output_dim_FBP2d_astra" ,
3435 "_calc_output_dim_FBP3d_tomobar" ,
3536 "_calc_output_dim_LPRec3d_tomobar" ,
3637 "_calc_output_dim_SIRT3d_tomobar" ,
3738 "_calc_output_dim_CGLS3d_tomobar" ,
39+ "_calc_output_dim_FISTA3d_tomobar" ,
40+ "_calc_padding_FISTA3d_tomobar" ,
3841]
3942
4043
44+ def _calc_padding_FISTA3d_tomobar (** kwargs ) -> Tuple [int , int ]:
45+ return (5 , 5 )
46+
47+
4148def __calc_output_dim_recon (non_slice_dims_shape , ** kwargs ):
4249 """Function to calculate output dimensions for all reconstructors.
4350 The change of the dimension depends either on the user-provided "recon_size"
@@ -72,14 +79,21 @@ def _calc_output_dim_CGLS3d_tomobar(non_slice_dims_shape, **kwargs):
7279 return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
7380
7481
82+ def _calc_output_dim_FISTA3d_tomobar (non_slice_dims_shape , ** kwargs ):
83+ return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
84+
85+
7586def _calc_memory_bytes_FBP3d_tomobar (
7687 non_slice_dims_shape : Tuple [int , int ],
7788 dtype : np .dtype ,
7889 ** kwargs ,
7990) -> Tuple [int , int ]:
91+ detector_pad = 0
8092 if "detector_pad" in kwargs :
8193 detector_pad = kwargs ["detector_pad" ]
82- else :
94+ if detector_pad is True :
95+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
96+ elif detector_pad is False :
8397 detector_pad = 0
8498
8599 angles_tot = non_slice_dims_shape [0 ]
@@ -169,10 +183,13 @@ def _calc_memory_bytes_LPRec3d_tomobar(
169183) -> Tuple [int , int ]:
170184 # Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39
171185
186+ detector_pad = 0
172187 if "detector_pad" in kwargs :
173188 detector_pad = kwargs ["detector_pad" ]
174- else :
175- detector_pad = 0
189+ if detector_pad is True :
190+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
191+ elif detector_pad is False :
192+ detector_pad = 0
176193
177194 min_mem_usage_filter = False
178195 min_mem_usage_ifft2 = False
@@ -388,21 +405,48 @@ def _calc_memory_bytes_SIRT3d_tomobar(
388405 ** kwargs ,
389406) -> Tuple [int , int ]:
390407
408+ detector_pad = 0
391409 if "detector_pad" in kwargs :
392410 detector_pad = kwargs ["detector_pad" ]
393- else :
394- detector_pad = 0
411+ if detector_pad is True :
412+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
413+ elif detector_pad is False :
414+ detector_pad = 0
415+
395416 anglesnum = non_slice_dims_shape [0 ]
396- DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
417+ DetectorsLengthH_padded = non_slice_dims_shape [1 ] + 2 * detector_pad
397418 # calculate the output shape
398419 output_dims = _calc_output_dim_SIRT3d_tomobar (non_slice_dims_shape , ** kwargs )
420+ recon_data_size_original = (
421+ np .prod (output_dims ) * dtype .itemsize
422+ ) # x_rec user-defined size
423+
424+ in_data_size = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
425+
426+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
427+
428+ out_data_size = np .prod (output_dims_larger_grid ) * dtype .itemsize
399429
400- in_data_size = ( anglesnum * DetectorsLengthH ) * dtype . itemsize
401- out_data_size = np . prod ( output_dims ) * dtype . itemsize
430+ R = in_data_size
431+ C = out_data_size
402432
403- astra_projection = 2.5 * (in_data_size + out_data_size )
433+ Res = in_data_size
434+ Res_times_R = Res
435+ C_times_res = out_data_size
404436
405- tot_memory_bytes = int (2 * in_data_size + 2 * out_data_size + astra_projection )
437+ astra_projection = (in_data_size + out_data_size )
438+
439+ tot_memory_bytes = int (
440+ recon_data_size_original
441+ + in_data_size
442+ + out_data_size
443+ + R
444+ + C
445+ + Res
446+ + Res_times_R
447+ + C_times_res
448+ + astra_projection
449+ )
406450 return (tot_memory_bytes , 0 )
407451
408452
@@ -411,20 +455,100 @@ def _calc_memory_bytes_CGLS3d_tomobar(
411455 dtype : np .dtype ,
412456 ** kwargs ,
413457) -> Tuple [int , int ]:
458+ detector_pad = 0
414459 if "detector_pad" in kwargs :
415460 detector_pad = kwargs ["detector_pad" ]
416- else :
417- detector_pad = 0
461+ if detector_pad is True :
462+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
463+ elif detector_pad is False :
464+ detector_pad = 0
418465
419466 anglesnum = non_slice_dims_shape [0 ]
420- DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
467+ DetectorsLengthH_padded = non_slice_dims_shape [1 ] + 2 * detector_pad
421468 # calculate the output shape
422469 output_dims = _calc_output_dim_CGLS3d_tomobar (non_slice_dims_shape , ** kwargs )
470+ recon_data_size_original = (
471+ np .prod (output_dims ) * dtype .itemsize
472+ ) # x_rec user-defined size
473+
474+ in_data_size = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
475+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
476+ recon_data_size = (
477+ np .prod (output_dims_larger_grid ) * dtype .itemsize
478+ ) # large volume in the algorithm
479+ recon_data_size2 = recon_data_size # x_rec linearised
480+ d_recon = recon_data_size
481+ d_recon2 = d_recon # linearised, possibly a copy
482+
483+ data_r = in_data_size
484+ Ad = recon_data_size
485+ Ad2 = Ad
486+ s = data_r
487+ collection = (
488+ in_data_size
489+ + recon_data_size_original
490+ + recon_data_size
491+ + recon_data_size2
492+ + d_recon
493+ + d_recon2
494+ + data_r
495+ + Ad
496+ + Ad2
497+ + s
498+ )
499+ astra_contribution = in_data_size + recon_data_size
500+
501+ tot_memory_bytes = int (collection + astra_contribution )
502+ return (tot_memory_bytes , 0 )
503+
504+
505+ def _calc_memory_bytes_FISTA3d_tomobar (
506+ non_slice_dims_shape : Tuple [int , int ],
507+ dtype : np .dtype ,
508+ ** kwargs ,
509+ ) -> Tuple [int , int ]:
510+ detector_pad = 0
511+ if "detector_pad" in kwargs :
512+ detector_pad = kwargs ["detector_pad" ]
513+ if detector_pad is True :
514+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
515+ elif detector_pad is False :
516+ detector_pad = 0
423517
424- in_data_size = ( anglesnum * DetectorsLengthH ) * dtype . itemsize
425- out_data_size = np . prod ( output_dims ) * dtype . itemsize
518+ anglesnum = non_slice_dims_shape [ 0 ]
519+ DetectorsLengthH_padded = non_slice_dims_shape [ 1 ] + 2 * detector_pad
426520
427- astra_projection = 2.5 * (in_data_size + out_data_size )
521+ # calculate the output shape
522+ output_dims = _calc_output_dim_FISTA3d_tomobar (non_slice_dims_shape , ** kwargs )
523+ recon_data_size_original = (
524+ np .prod (output_dims ) * dtype .itemsize
525+ ) # recon user-defined size
526+
527+ in_data_siz_pad = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
528+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
529+
530+ residual_grad = in_data_siz_pad
531+ out_data_size = np .prod (output_dims_larger_grid ) * dtype .itemsize
532+ X_t = out_data_size
533+ X_old = out_data_size
534+
535+ grad_fidelity = out_data_size
536+
537+ fista_part = (
538+ recon_data_size_original
539+ + in_data_siz_pad
540+ + residual_grad
541+ + grad_fidelity
542+ + X_t
543+ + X_old
544+ + out_data_size
545+ )
546+ regul_part = 8 * np .prod (output_dims_larger_grid ) * dtype .itemsize
428547
429- tot_memory_bytes = int (2 * in_data_size + 2 * out_data_size + astra_projection )
548+ tot_memory_bytes = int (fista_part + regul_part )
430549 return (tot_memory_bytes , 0 )
550+
551+
552+ def __estimate_detectorHoriz_padding (detX_size ) -> int :
553+ det_half = detX_size // 2
554+ return int (np .sqrt (2 * (det_half ** 2 )) // 2 )
0 commit comments