3030 "_calc_memory_bytes_LPRec3d_tomobar" ,
3131 "_calc_memory_bytes_SIRT3d_tomobar" ,
3232 "_calc_memory_bytes_CGLS3d_tomobar" ,
33+ "_calc_memory_bytes_FISTA3d_tomobar" ,
3334 "_calc_output_dim_FBP2d_astra" ,
3435 "_calc_output_dim_FBP3d_tomobar" ,
3536 "_calc_output_dim_LPRec3d_tomobar" ,
3637 "_calc_output_dim_SIRT3d_tomobar" ,
3738 "_calc_output_dim_CGLS3d_tomobar" ,
39+ "_calc_output_dim_FISTA3d_tomobar" ,
40+ "_calc_padding_FISTA3d_tomobar" ,
3841]
3942
4043
44+ def _calc_padding_FISTA3d_tomobar (** kwargs ) -> Tuple [int , int ]:
45+ return (5 , 5 )
46+
47+
4148def __calc_output_dim_recon (non_slice_dims_shape , ** kwargs ):
4249 """Function to calculate output dimensions for all reconstructors.
4350 The change of the dimension depends either on the user-provided "recon_size"
@@ -72,14 +79,21 @@ def _calc_output_dim_CGLS3d_tomobar(non_slice_dims_shape, **kwargs):
7279 return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
7380
7481
82+ def _calc_output_dim_FISTA3d_tomobar (non_slice_dims_shape , ** kwargs ):
83+ return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
84+
85+
7586def _calc_memory_bytes_FBP3d_tomobar (
7687 non_slice_dims_shape : Tuple [int , int ],
7788 dtype : np .dtype ,
7889 ** kwargs ,
7990) -> Tuple [int , int ]:
91+ detector_pad = 0
8092 if "detector_pad" in kwargs :
8193 detector_pad = kwargs ["detector_pad" ]
82- else :
94+ if detector_pad is True :
95+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
96+ elif detector_pad is False :
8397 detector_pad = 0
8498
8599 angles_tot = non_slice_dims_shape [0 ]
@@ -169,10 +183,13 @@ def _calc_memory_bytes_LPRec3d_tomobar(
169183) -> Tuple [int , int ]:
170184 # Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39
171185
186+ detector_pad = 0
172187 if "detector_pad" in kwargs :
173188 detector_pad = kwargs ["detector_pad" ]
174- else :
175- detector_pad = 0
189+ if detector_pad is True :
190+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
191+ elif detector_pad is False :
192+ detector_pad = 0
176193
177194 min_mem_usage_filter = False
178195 min_mem_usage_ifft2 = False
@@ -397,21 +414,48 @@ def _calc_memory_bytes_SIRT3d_tomobar(
397414 ** kwargs ,
398415) -> Tuple [int , int ]:
399416
417+ detector_pad = 0
400418 if "detector_pad" in kwargs :
401419 detector_pad = kwargs ["detector_pad" ]
402- else :
403- detector_pad = 0
420+ if detector_pad is True :
421+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
422+ elif detector_pad is False :
423+ detector_pad = 0
424+
404425 anglesnum = non_slice_dims_shape [0 ]
405- DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
426+ DetectorsLengthH_padded = non_slice_dims_shape [1 ] + 2 * detector_pad
406427 # calculate the output shape
407428 output_dims = _calc_output_dim_SIRT3d_tomobar (non_slice_dims_shape , ** kwargs )
429+ recon_data_size_original = (
430+ np .prod (output_dims ) * dtype .itemsize
431+ ) # x_rec user-defined size
432+
433+ in_data_size = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
434+
435+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
436+
437+ out_data_size = np .prod (output_dims_larger_grid ) * dtype .itemsize
408438
409- in_data_size = ( anglesnum * DetectorsLengthH ) * dtype . itemsize
410- out_data_size = np . prod ( output_dims ) * dtype . itemsize
439+ R = in_data_size
440+ C = out_data_size
411441
412- astra_projection = 2.5 * (in_data_size + out_data_size )
442+ Res = in_data_size
443+ Res_times_R = Res
444+ C_times_res = out_data_size
413445
414- tot_memory_bytes = int (2 * in_data_size + 2 * out_data_size + astra_projection )
446+ astra_projection = (in_data_size + out_data_size )
447+
448+ tot_memory_bytes = int (
449+ recon_data_size_original
450+ + in_data_size
451+ + out_data_size
452+ + R
453+ + C
454+ + Res
455+ + Res_times_R
456+ + C_times_res
457+ + astra_projection
458+ )
415459 return (tot_memory_bytes , 0 )
416460
417461
@@ -420,20 +464,100 @@ def _calc_memory_bytes_CGLS3d_tomobar(
420464 dtype : np .dtype ,
421465 ** kwargs ,
422466) -> Tuple [int , int ]:
467+ detector_pad = 0
423468 if "detector_pad" in kwargs :
424469 detector_pad = kwargs ["detector_pad" ]
425- else :
426- detector_pad = 0
470+ if detector_pad is True :
471+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
472+ elif detector_pad is False :
473+ detector_pad = 0
427474
428475 anglesnum = non_slice_dims_shape [0 ]
429- DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
476+ DetectorsLengthH_padded = non_slice_dims_shape [1 ] + 2 * detector_pad
430477 # calculate the output shape
431478 output_dims = _calc_output_dim_CGLS3d_tomobar (non_slice_dims_shape , ** kwargs )
479+ recon_data_size_original = (
480+ np .prod (output_dims ) * dtype .itemsize
481+ ) # x_rec user-defined size
482+
483+ in_data_size = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
484+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
485+ recon_data_size = (
486+ np .prod (output_dims_larger_grid ) * dtype .itemsize
487+ ) # large volume in the algorithm
488+ recon_data_size2 = recon_data_size # x_rec linearised
489+ d_recon = recon_data_size
490+ d_recon2 = d_recon # linearised, possibly a copy
491+
492+ data_r = in_data_size
493+ Ad = recon_data_size
494+ Ad2 = Ad
495+ s = data_r
496+ collection = (
497+ in_data_size
498+ + recon_data_size_original
499+ + recon_data_size
500+ + recon_data_size2
501+ + d_recon
502+ + d_recon2
503+ + data_r
504+ + Ad
505+ + Ad2
506+ + s
507+ )
508+ astra_contribution = in_data_size + recon_data_size
509+
510+ tot_memory_bytes = int (collection + astra_contribution )
511+ return (tot_memory_bytes , 0 )
512+
513+
514+ def _calc_memory_bytes_FISTA3d_tomobar (
515+ non_slice_dims_shape : Tuple [int , int ],
516+ dtype : np .dtype ,
517+ ** kwargs ,
518+ ) -> Tuple [int , int ]:
519+ detector_pad = 0
520+ if "detector_pad" in kwargs :
521+ detector_pad = kwargs ["detector_pad" ]
522+ if detector_pad is True :
523+ detector_pad = __estimate_detectorHoriz_padding (non_slice_dims_shape [1 ])
524+ elif detector_pad is False :
525+ detector_pad = 0
432526
433- in_data_size = ( anglesnum * DetectorsLengthH ) * dtype . itemsize
434- out_data_size = np . prod ( output_dims ) * dtype . itemsize
527+ anglesnum = non_slice_dims_shape [ 0 ]
528+ DetectorsLengthH_padded = non_slice_dims_shape [ 1 ] + 2 * detector_pad
435529
436- astra_projection = 2.5 * (in_data_size + out_data_size )
530+ # calculate the output shape
531+ output_dims = _calc_output_dim_FISTA3d_tomobar (non_slice_dims_shape , ** kwargs )
532+ recon_data_size_original = (
533+ np .prod (output_dims ) * dtype .itemsize
534+ ) # recon user-defined size
535+
536+ in_data_siz_pad = (anglesnum * DetectorsLengthH_padded ) * dtype .itemsize
537+ output_dims_larger_grid = (DetectorsLengthH_padded , DetectorsLengthH_padded )
538+
539+ residual_grad = in_data_siz_pad
540+ out_data_size = np .prod (output_dims_larger_grid ) * dtype .itemsize
541+ X_t = out_data_size
542+ X_old = out_data_size
543+
544+ grad_fidelity = out_data_size
545+
546+ fista_part = (
547+ recon_data_size_original
548+ + in_data_siz_pad
549+ + residual_grad
550+ + grad_fidelity
551+ + X_t
552+ + X_old
553+ + out_data_size
554+ )
555+ regul_part = 8 * np .prod (output_dims_larger_grid ) * dtype .itemsize
437556
438- tot_memory_bytes = int (2 * in_data_size + 2 * out_data_size + astra_projection )
557+ tot_memory_bytes = int (fista_part + regul_part )
439558 return (tot_memory_bytes , 0 )
559+
560+
561+ def __estimate_detectorHoriz_padding (detX_size ) -> int :
562+ det_half = detX_size // 2
563+ return int (np .sqrt (2 * (det_half ** 2 )) // 2 )
0 commit comments