Skip to content

Commit c7fec76

Browse files
authored
Merge pull request #75 from DiamondLightSource/projection_chunking
Bring LPRec the memory estimator and its tests up to date with the latest ToMoBAR
2 parents 2b7efb2 + ba2f195 commit c7fec76

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

httomo_backends/methods_database/packages/backends/httomolibgpu/supporting_funcs/recon/algorithm.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,19 @@ def _calc_memory_bytes_LPRec3d_tomobar(
218218
center_size = 32768
219219
center_size = min(center_size, n * 2)
220220

221-
oversampling_level = 2 # at least 2 or larger required
222-
ne = oversampling_level * n
221+
chunk_count = 4
222+
projection_chunk_count = 4
223+
oversampling_level = 4 # at least 3 or larger required
224+
power_of_2_oversampling = True
225+
226+
if power_of_2_oversampling:
227+
ne = 2 ** math.ceil(math.log2(DetectorsLengthH_prepad * 3))
228+
if n > ne:
229+
ne = 2 ** math.ceil(math.log2(n))
230+
else:
231+
ne = int(oversampling_level * DetectorsLengthH_prepad)
232+
ne = max(ne, n)
233+
223234
padding_m = ne // 2 - n // 2
224235

225236
if "angles" in kwargs:
@@ -233,8 +244,6 @@ def _calc_memory_bytes_LPRec3d_tomobar(
233244
np.ceil(2)
234245
) # assume a 2 * PI projection angle range
235246

236-
chunk_count = 4
237-
238247
output_dims = __calc_output_dim_recon(non_slice_dims_shape, **kwargs)
239248
if odd_horiz:
240249
output_dims = tuple(x + 1 for x in output_dims)
@@ -346,23 +355,23 @@ def add_to_memory_counters(amount, per_slice: bool):
346355

347356
add_to_memory_counters(-irfft_result_size, False)
348357
else:
349-
add_to_memory_counters(rfft_plan_slice_size / chunk_count * 2, True)
350-
add_to_memory_counters(irfft_plan_slice_size / chunk_count * 2, True)
351-
# add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
358+
add_to_memory_counters(rfft_plan_slice_size / chunk_count / projection_chunk_count * 2, True)
359+
add_to_memory_counters(irfft_plan_slice_size / chunk_count / projection_chunk_count * 2, True)
360+
# add_to_memory_counters(irfft_scratch_memory_size / chunk_count / projection_chunk_count, True)
352361
for _ in range(0, chunk_count):
353-
add_to_memory_counters(padded_tmp_p_input_slice / chunk_count, True)
362+
add_to_memory_counters(padded_tmp_p_input_slice / chunk_count / projection_chunk_count, True)
354363

355-
add_to_memory_counters(rfft_result_size / chunk_count, True)
356-
add_to_memory_counters(filtered_rfft_result_size / chunk_count, True)
357-
add_to_memory_counters(-rfft_result_size / chunk_count, True)
358-
add_to_memory_counters(-padded_tmp_p_input_slice / chunk_count, True)
364+
add_to_memory_counters(rfft_result_size / chunk_count / projection_chunk_count, True)
365+
add_to_memory_counters(filtered_rfft_result_size / chunk_count / projection_chunk_count, True)
366+
add_to_memory_counters(-rfft_result_size / chunk_count / projection_chunk_count, True)
367+
add_to_memory_counters(-padded_tmp_p_input_slice / chunk_count / projection_chunk_count, True)
359368

360-
add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
361-
add_to_memory_counters(-irfft_scratch_memory_size / chunk_count, True)
362-
add_to_memory_counters(irfft_result_size / chunk_count, True)
363-
add_to_memory_counters(-filtered_rfft_result_size / chunk_count, True)
369+
add_to_memory_counters(irfft_scratch_memory_size / chunk_count / projection_chunk_count, True)
370+
add_to_memory_counters(-irfft_scratch_memory_size / chunk_count / projection_chunk_count, True)
371+
add_to_memory_counters(irfft_result_size / chunk_count / projection_chunk_count, True)
372+
add_to_memory_counters(-filtered_rfft_result_size / chunk_count / projection_chunk_count, True)
364373

365-
add_to_memory_counters(-irfft_result_size / chunk_count, True)
374+
add_to_memory_counters(-irfft_result_size / chunk_count / projection_chunk_count, True)
366375

367376
add_to_memory_counters(-padded_in_slice_size, True)
368377
add_to_memory_counters(-filter_size, False)
@@ -396,7 +405,7 @@ def add_to_memory_counters(amount, per_slice: bool):
396405
if min_mem_usage_ifft2 and min_mem_usage_filter:
397406
return (tot_memory_bytes * 1.1 + 30 * 1024 * 1024, fixed_amount)
398407
else:
399-
return (tot_memory_bytes, fixed_amount)
408+
return (tot_memory_bytes * 1.1, fixed_amount)
400409

401410

402411
def _calc_memory_bytes_SIRT3d_tomobar(
@@ -551,4 +560,6 @@ def _calc_memory_bytes_FISTA3d_tomobar(
551560

552561
def __estimate_detectorHoriz_padding(detX_size) -> int:
553562
det_half = detX_size // 2
554-
return int(np.sqrt(2 * (det_half**2)) // 2)
563+
padded_value_exact = int(np.sqrt(2 * (det_half**2))) - det_half
564+
padded_add_margin = int(0.1 * padded_value_exact)
565+
return padded_value_exact + padded_add_margin

tests/test_httomolibgpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook(
568568

569569
@pytest.mark.full
570570
@pytest.mark.cupy
571-
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
571+
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100, 800])
572572
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
573573
@pytest.mark.parametrize("detX_size", [2560])
574574
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
@@ -593,7 +593,7 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook_full(
593593

594594
@pytest.mark.full
595595
@pytest.mark.cupy
596-
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
596+
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100, 800])
597597
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
598598
@pytest.mark.parametrize("detX_size", [2560])
599599
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
@@ -676,7 +676,7 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
676676
# the estimated_memory_mb should be LARGER or EQUAL to max_mem_mb
677677
# the resulting percent value should not deviate from max_mem on more than 20%
678678
assert estimated_memory_mb >= max_mem_mb
679-
assert percents_relative_maxmem <= 90
679+
assert percents_relative_maxmem <= 60
680680

681681

682682
@pytest.mark.cupy

0 commit comments

Comments
 (0)