Skip to content

Commit 5bbd9ef

Browse files
authored
Merge pull request #77 from DiamondLightSource/projection_chunking
Projection chunking
2 parents 9c629f9 + 81f2cd6 commit 5bbd9ef

File tree

2 files changed

+32
-14
lines changed

2 files changed

+32
-14
lines changed

httomo_backends/methods_database/packages/backends/httomolibgpu/supporting_funcs/recon/algorithm.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ def _calc_memory_bytes_LPRec3d_tomobar(
192192
detector_pad = 0
193193

194194
min_mem_usage_filter = False
195+
if "min_mem_usage_filter" in kwargs:
196+
min_mem_usage_filter = kwargs["min_mem_usage_filter"]
195197
min_mem_usage_ifft2 = False
198+
if "min_mem_usage_ifft2" in kwargs:
199+
min_mem_usage_ifft2 = kwargs["min_mem_usage_ifft2"]
196200

197201
angles_tot = non_slice_dims_shape[0]
198202
DetectorsLengthH_prepad = non_slice_dims_shape[1]
@@ -339,21 +343,21 @@ def add_to_memory_counters(amount, per_slice: bool):
339343

340344
add_to_memory_counters(tmp_p_input_slice, True)
341345
if min_mem_usage_filter:
342-
add_to_memory_counters(rfft_plan_slice_size / 4, False)
343-
add_to_memory_counters(irfft_plan_slice_size / 4, False)
344-
add_to_memory_counters(padded_tmp_p_input_slice, False)
346+
add_to_memory_counters(rfft_plan_slice_size / 4 / projection_chunk_count, False)
347+
add_to_memory_counters(irfft_plan_slice_size / 4 / projection_chunk_count, False)
348+
add_to_memory_counters(padded_tmp_p_input_slice / projection_chunk_count, False)
345349

346-
add_to_memory_counters(rfft_result_size, False)
347-
add_to_memory_counters(filtered_rfft_result_size, False)
348-
add_to_memory_counters(-rfft_result_size, False)
349-
add_to_memory_counters(-padded_tmp_p_input_slice, False)
350+
add_to_memory_counters(rfft_result_size / projection_chunk_count, False)
351+
add_to_memory_counters(filtered_rfft_result_size / projection_chunk_count, False)
352+
add_to_memory_counters(-rfft_result_size / projection_chunk_count, False)
353+
add_to_memory_counters(-padded_tmp_p_input_slice / projection_chunk_count, False)
350354

351-
add_to_memory_counters(irfft_scratch_memory_size, False)
352-
add_to_memory_counters(-irfft_scratch_memory_size, False)
353-
add_to_memory_counters(irfft_result_size, False)
354-
add_to_memory_counters(-filtered_rfft_result_size, False)
355+
add_to_memory_counters(irfft_scratch_memory_size / projection_chunk_count, False)
356+
add_to_memory_counters(-irfft_scratch_memory_size / projection_chunk_count, False)
357+
add_to_memory_counters(irfft_result_size / projection_chunk_count, False)
358+
add_to_memory_counters(-filtered_rfft_result_size / projection_chunk_count, False)
355359

356-
add_to_memory_counters(-irfft_result_size, False)
360+
add_to_memory_counters(-irfft_result_size / projection_chunk_count, False)
357361
else:
358362
add_to_memory_counters(rfft_plan_slice_size / chunk_count / projection_chunk_count * 2, True)
359363
add_to_memory_counters(irfft_plan_slice_size / chunk_count / projection_chunk_count * 2, True)
@@ -402,8 +406,10 @@ def add_to_memory_counters(amount, per_slice: bool):
402406
add_to_memory_counters(circular_mask_size, False)
403407
add_to_memory_counters(after_recon_swapaxis_slice, True)
404408

405-
if min_mem_usage_ifft2 and min_mem_usage_filter:
406-
return (tot_memory_bytes * 1.1 + 30 * 1024 * 1024, fixed_amount)
409+
if min_mem_usage_filter and min_mem_usage_ifft2:
410+
return (tot_memory_bytes * 1.15, fixed_amount)
411+
elif min_mem_usage_filter and not min_mem_usage_ifft2:
412+
return (tot_memory_bytes + 60 * 1024 * 1024, fixed_amount)
407413
else:
408414
return (tot_memory_bytes * 1.1, fixed_amount)
409415

tests/test_httomolibgpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ def test_recon_FBP3d_tomobar_memoryhook(
543543

544544

545545
@pytest.mark.cupy
546+
@pytest.mark.parametrize("min_mem_usage_filter_ifft2", [(False, False), (True, False), (True, True)])
546547
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
547548
@pytest.mark.parametrize("projections", [1500, 1801, 2560])
548549
@pytest.mark.parametrize("detX_size", [2560])
@@ -553,6 +554,7 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook(
553554
detX_size,
554555
projections,
555556
projection_angle_range,
557+
min_mem_usage_filter_ifft2,
556558
padding_detx,
557559
ensure_clean_memory,
558560
):
@@ -562,12 +564,14 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook(
562564
projections,
563565
projection_angle_range,
564566
padding_detx,
567+
min_mem_usage_filter_ifft2,
565568
ensure_clean_memory,
566569
)
567570

568571

569572
@pytest.mark.full
570573
@pytest.mark.cupy
574+
@pytest.mark.parametrize("min_mem_usage_filter_ifft2", [(False, False), (True, False), (True, True)])
571575
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100, 800])
572576
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
573577
@pytest.mark.parametrize("detX_size", [2560])
@@ -579,6 +583,7 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook_full(
579583
projections,
580584
projection_angle_range,
581585
padding_detx,
586+
min_mem_usage_filter_ifft2,
582587
ensure_clean_memory,
583588
):
584589
__test_recon_LPRec3d_tomobar_memoryhook_common(
@@ -587,12 +592,14 @@ def test_recon_LPRec3d_tomobar_0_pi_memoryhook_full(
587592
projections,
588593
projection_angle_range,
589594
padding_detx,
595+
min_mem_usage_filter_ifft2,
590596
ensure_clean_memory,
591597
)
592598

593599

594600
@pytest.mark.full
595601
@pytest.mark.cupy
602+
@pytest.mark.parametrize("min_mem_usage_filter_ifft2", [(False, False), (True, False), (True, True)])
596603
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100, 800])
597604
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
598605
@pytest.mark.parametrize("detX_size", [2560])
@@ -606,6 +613,7 @@ def test_recon_LPRec3d_tomobar_memoryhook_full(
606613
projections,
607614
projection_angle_range,
608615
padding_detx,
616+
min_mem_usage_filter_ifft2,
609617
ensure_clean_memory,
610618
):
611619
__test_recon_LPRec3d_tomobar_memoryhook_common(
@@ -614,6 +622,7 @@ def test_recon_LPRec3d_tomobar_memoryhook_full(
614622
projections,
615623
projection_angle_range,
616624
padding_detx,
625+
min_mem_usage_filter_ifft2,
617626
ensure_clean_memory,
618627
)
619628

@@ -624,6 +633,7 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
624633
projections,
625634
projection_angle_range,
626635
padding_detx,
636+
min_mem_usage_filter_ifft2,
627637
ensure_clean_memory,
628638
):
629639
angles_number = projections
@@ -634,6 +644,8 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
634644
)
635645
kwargs["center"] = 1280
636646
kwargs["detector_pad"] = padding_detx
647+
kwargs["min_mem_usage_filter"] = min_mem_usage_filter_ifft2[0]
648+
kwargs["min_mem_usage_ifft2"] = min_mem_usage_filter_ifft2[1]
637649
kwargs["recon_size"] = detX_size
638650
kwargs["recon_mask_radius"] = 0.8
639651

0 commit comments

Comments
 (0)