Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@ def _calc_memory_bytes_FBP3d_tomobar(
dtype: np.dtype,
**kwargs,
) -> Tuple[int, int]:
det_height = non_slice_dims_shape[0]
det_width = non_slice_dims_shape[1]
if "detector_pad" in kwargs:
detector_pad = kwargs["detector_pad"]
else:
detector_pad = 0

angles_tot = non_slice_dims_shape[0]
det_width = non_slice_dims_shape[1] + 2 * detector_pad
SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold

# 1. input
input_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
input_slice_size = (angles_tot * det_width) * dtype.itemsize

########## FFT / filter / IFFT (filtersync_cupy)

Expand All @@ -91,13 +96,13 @@ def _calc_memory_bytes_FBP3d_tomobar(
cufft_estimate_1d(
nx=det_width,
fft_type=CufftType.CUFFT_R2C,
batch=det_height * SLICES,
batch=angles_tot * SLICES,
)
/ SLICES
)

# 3. RFFT output size (proj_f in code)
proj_f_slice = det_height * (det_width // 2 + 1) * np.complex64().itemsize
proj_f_slice = angles_tot * (det_width // 2 + 1) * np.complex64().itemsize

# 4. Filter size (independent of number of slices)
filter_size = (det_width // 2 + 1) * np.float32().itemsize
Expand All @@ -107,7 +112,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
cufft_estimate_1d(
nx=det_width,
fft_type=CufftType.CUFFT_C2R,
batch=det_height * SLICES,
batch=angles_tot * SLICES,
)
/ SLICES
)
Expand All @@ -123,9 +128,7 @@ def _calc_memory_bytes_FBP3d_tomobar(

# 6. we swap the axes before passing data to Astra in ToMoBAR
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/methodsDIR_CuPy.py#L135
pre_astra_input_swapaxis_slice = (
np.prod(non_slice_dims_shape) * np.float32().itemsize
)
pre_astra_input_swapaxis_slice = (angles_tot * det_width) * np.float32().itemsize

# 7. astra backprojection will generate an output array
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/astra_wrappers/astra_base.py#L524
Expand All @@ -151,7 +154,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
# so it does not add to the memory overall

# We assume for safety here that one FFT plan is not freed and one is freed
tot_memory_bytes = (
tot_memory_bytes = int(
projection_mem_size + filtersync_size - ifftplan_slice_size + recon_output_size
)

Expand All @@ -166,8 +169,14 @@ def _calc_memory_bytes_LPRec3d_tomobar(
) -> Tuple[int, int]:
# Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39

if "detector_pad" in kwargs:
detector_pad = kwargs["detector_pad"]
else:
detector_pad = 0

angles_tot = non_slice_dims_shape[0]
DetectorsLengthH = non_slice_dims_shape[1]
DetectorsLengthH_prepad = non_slice_dims_shape[1]
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold
_CENTER_SIZE_MIN = 192 # must be divisible by 8

Expand Down Expand Up @@ -210,7 +219,7 @@ def _calc_memory_bytes_LPRec3d_tomobar(
if odd_horiz:
output_dims = tuple(x + 1 for x in output_dims)

in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
in_slice_size = (angles_tot * DetectorsLengthH) * dtype.itemsize
padded_in_slice_size = angles_tot * n * np.float32().itemsize

theta_size = angles_tot * np.float32().itemsize
Expand Down Expand Up @@ -256,7 +265,9 @@ def _calc_memory_bytes_LPRec3d_tomobar(
center_size * center_size * (1 + angle_range_pi_count * 2) * np.int16().itemsize
)

recon_output_size = DetectorsLengthH * DetectorsLengthH * np.float32().itemsize
recon_output_size = (
DetectorsLengthH_prepad * DetectorsLengthH_prepad * np.float32().itemsize
)
ifft2_plan_slice_size = (
cufft_estimate_2d(
nx=(2 * m + 2 * n), ny=(2 * m + 2 * n), fft_type=CufftType.CUFFT_C2C
Expand Down Expand Up @@ -342,24 +353,28 @@ def add_to_memory_counters(amount, per_slice: bool):
add_to_memory_counters(after_recon_swapaxis_slice, True)

return (tot_memory_bytes * 1.05, fixed_amount + 250 * 1024 * 1024)
# return (tot_memory_bytes, fixed_amount)


def _calc_memory_bytes_SIRT3d_tomobar(
    non_slice_dims_shape: Tuple[int, int],
    dtype: np.dtype,
    **kwargs,
) -> Tuple[int, int]:
    """Estimate the memory (in bytes) required by the SIRT3d ToMoBAR method.

    Args:
        non_slice_dims_shape: ``(angles, detector_width)`` of the input, i.e.
            index 0 is used as the number of angles and index 1 as the
            horizontal detector length before padding.
        dtype: data type of the input; its ``itemsize`` is used for both the
            input and the output size.
        **kwargs: method parameters. ``detector_pad`` (default ``0``) is read
            here; all of them are forwarded to
            ``_calc_output_dim_SIRT3d_tomobar``.

    Returns:
        ``(tot_memory_bytes, 0)`` where ``tot_memory_bytes`` is an ``int``.
    """
    # The padding is added on both sides of the horizontal detector dimension,
    # hence the factor of two below.
    detector_pad = kwargs.get("detector_pad", 0)

    anglesnum = non_slice_dims_shape[0]
    DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad

    # calculate the output shape
    output_dims = _calc_output_dim_SIRT3d_tomobar(non_slice_dims_shape, **kwargs)

    # The input size is based on the padded detector width rather than on the
    # raw ``non_slice_dims_shape``.
    in_data_size = (anglesnum * DetectorsLengthH) * dtype.itemsize
    out_data_size = np.prod(output_dims) * dtype.itemsize

    # NOTE(review): 2.5 looks like an empirical allowance for the ASTRA
    # projection step - confirm against measurements.
    astra_projection = 2.5 * (in_data_size + out_data_size)

    # ``astra_projection`` is a float, so convert explicitly to honour the
    # declared ``Tuple[int, int]`` return type.
    tot_memory_bytes = int(2 * in_data_size + 2 * out_data_size + astra_projection)
    return (tot_memory_bytes, 0)


Expand All @@ -368,14 +383,20 @@ def _calc_memory_bytes_CGLS3d_tomobar(
dtype: np.dtype,
**kwargs,
) -> Tuple[int, int]:
DetectorsLengthH = non_slice_dims_shape[1]
if "detector_pad" in kwargs:
detector_pad = kwargs["detector_pad"]
else:
detector_pad = 0

anglesnum = non_slice_dims_shape[0]
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
# calculate the output shape
output_dims = _calc_output_dim_CGLS3d_tomobar(non_slice_dims_shape, **kwargs)

in_data_size = np.prod(non_slice_dims_shape) * dtype.itemsize
in_data_size = (anglesnum * DetectorsLengthH) * dtype.itemsize
out_data_size = np.prod(output_dims) * dtype.itemsize

astra_projection = 2.5 * (in_data_size + out_data_size)

tot_memory_bytes = 2 * in_data_size + 2 * out_data_size + astra_projection
tot_memory_bytes = int(2 * in_data_size + 2 * out_data_size + astra_projection)
return (tot_memory_bytes, 0)
77 changes: 56 additions & 21 deletions tests/test_httomolibgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,17 +528,24 @@ def test_data_sampler_memoryhook(slices, newshape, interpolation, ensure_clean_m


@pytest.mark.cupy
@pytest.mark.parametrize("padding_detx", [0, 10, 100, 200])
@pytest.mark.parametrize("projections", [1801, 3601])
@pytest.mark.parametrize("slices", [7, 11, 15])
@pytest.mark.parametrize("detectorX", [1200, 2560])
def test_recon_FBP3d_tomobar_memoryhook(
slices, detectorX, projections, ensure_clean_memory, mocker: MockerFixture
slices,
detectorX,
projections,
padding_detx,
ensure_clean_memory,
mocker: MockerFixture,
):
data = cp.random.random_sample((projections, slices, detectorX), dtype=np.float32)
kwargs = {}
kwargs["angles"] = np.linspace(
0.0 * np.pi / 180.0, 180.0 * np.pi / 180.0, data.shape[0]
)
kwargs["detector_pad"] = padding_detx
kwargs["center"] = 500
kwargs["recon_size"] = detectorX
kwargs["recon_mask_radius"] = 0.8
Expand Down Expand Up @@ -579,61 +586,88 @@ def test_recon_FBP3d_tomobar_memoryhook(


@pytest.mark.cupy
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
@pytest.mark.parametrize("projections", [1500, 1801, 2560])
@pytest.mark.parametrize("detX_size", [2560])
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
@pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
def test_recon_LPRec3d_tomobar_0_pi_memoryhook(
    slices,
    detX_size,
    projections,
    projection_angle_range,
    padding_detx,
    ensure_clean_memory,
):
    """Run the shared LPRec3d memory-hook check for the ``(0, pi)`` angle range.

    ``padding_detx`` is forwarded so that the check also covers padded
    detector widths.
    """
    __test_recon_LPRec3d_tomobar_memoryhook_common(
        slices,
        detX_size,
        projections,
        projection_angle_range,
        padding_detx,
        ensure_clean_memory,
    )


@pytest.mark.full
@pytest.mark.cupy
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
@pytest.mark.parametrize("detX_size", [2560])
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
@pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
def test_recon_LPRec3d_tomobar_0_pi_memoryhook_full(
    slices,
    detX_size,
    projections,
    projection_angle_range,
    padding_detx,
    ensure_clean_memory,
):
    """Run the shared LPRec3d memory-hook check for ``(0, pi)`` in the ``full`` suite.

    Uses the extended projection list (including 3601) and forwards
    ``padding_detx`` so padded detector widths are covered as well.
    """
    __test_recon_LPRec3d_tomobar_memoryhook_common(
        slices,
        detX_size,
        projections,
        projection_angle_range,
        padding_detx,
        ensure_clean_memory,
    )


@pytest.mark.full
@pytest.mark.cupy
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
@pytest.mark.parametrize("detX_size", [2560])
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
@pytest.mark.parametrize(
    "projection_angle_range", [(0, np.pi), (0, 2 * np.pi), (-np.pi / 2, np.pi / 2)]
)
def test_recon_LPRec3d_tomobar_memoryhook_full(
    slices,
    detX_size,
    projections,
    projection_angle_range,
    padding_detx,
    ensure_clean_memory,
):
    """Run the shared LPRec3d memory-hook check over several angle ranges.

    Part of the ``full`` suite: covers ``(0, pi)``, ``(0, 2*pi)`` and
    ``(-pi/2, pi/2)`` and forwards ``padding_detx`` so padded detector widths
    are covered as well.
    """
    __test_recon_LPRec3d_tomobar_memoryhook_common(
        slices,
        detX_size,
        projections,
        projection_angle_range,
        padding_detx,
        ensure_clean_memory,
    )


def __test_recon_LPRec3d_tomobar_memoryhook_common(
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
slices,
detX_size,
projections,
projection_angle_range,
padding_detx,
ensure_clean_memory,
):
angles_number = projections
data = cp.random.random_sample((angles_number, slices, detX_size), dtype=np.float32)
Expand All @@ -642,6 +676,7 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
projection_angle_range[0], projection_angle_range[1], data.shape[0]
)
kwargs["center"] = 1280
kwargs["detector_pad"] = padding_detx
kwargs["recon_size"] = detX_size
kwargs["recon_mask_radius"] = 0.8

Expand Down Expand Up @@ -687,9 +722,9 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
if slices <= 3:
assert percents_relative_maxmem <= 75
elif slices <= 5:
assert percents_relative_maxmem <= 60
assert percents_relative_maxmem <= 63
else:
assert percents_relative_maxmem <= 47
assert percents_relative_maxmem <= 50


@pytest.mark.cupy
Expand Down
Loading