Skip to content

Commit d0e79b5

Browse files
authored
Merge pull request #71 from DiamondLightSource/lprec_options
LPRec memory estimator update with options
2 parents 52f5a58 + 2700905 commit d0e79b5

File tree

2 files changed

+52
-25
lines changed

2 files changed

+52
-25
lines changed

httomo_backends/methods_database/packages/backends/httomolibgpu/supporting_funcs/recon/algorithm.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def _calc_memory_bytes_LPRec3d_tomobar(
174174
else:
175175
detector_pad = 0
176176

177+
min_mem_usage_filter = False
178+
min_mem_usage_ifft2 = False
179+
177180
angles_tot = non_slice_dims_shape[0]
178181
DetectorsLengthH_prepad = non_slice_dims_shape[1]
179182
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
@@ -195,8 +198,8 @@ def _calc_memory_bytes_LPRec3d_tomobar(
195198
)
196199
)
197200

198-
center_size = 6144
199-
center_size = min(center_size, n * 2 + m * 2)
201+
center_size = 32768
202+
center_size = min(center_size, n * 2)
200203

201204
oversampling_level = 2 # at least 2 or larger required
202205
ne = oversampling_level * n
@@ -252,7 +255,7 @@ def _calc_memory_bytes_LPRec3d_tomobar(
252255
irfft_result_size = angles_tot * (n + padding_m * 2) * np.float32().itemsize
253256

254257
datac_size = angles_tot * n * np.complex64().itemsize / 2
255-
fde_size = (2 * m + 2 * n) * (2 * m + 2 * n) * np.complex64().itemsize / 2
258+
fde_size = 2 * n * 2 * n * np.complex64().itemsize / 2
256259
fft_plan_slice_size = (
257260
cufft_estimate_1d(nx=n, fft_type=CufftType.CUFFT_C2C, batch=angles_tot * SLICES)
258261
/ SLICES
@@ -270,7 +273,7 @@ def _calc_memory_bytes_LPRec3d_tomobar(
270273
)
271274
ifft2_plan_slice_size = (
272275
cufft_estimate_2d(
273-
nx=(2 * m + 2 * n), ny=(2 * m + 2 * n), fft_type=CufftType.CUFFT_C2C
276+
nx=2 * n, ny=2 * n, fft_type=CufftType.CUFFT_C2C
274277
)
275278
/ 2
276279
)
@@ -309,24 +312,40 @@ def add_to_memory_counters(amount, per_slice: bool):
309312
add_to_memory_counters(scaled_filter_size, False)
310313

311314
add_to_memory_counters(tmp_p_input_slice, True)
315+
if min_mem_usage_filter:
316+
add_to_memory_counters(rfft_plan_slice_size / 4, False)
317+
add_to_memory_counters(irfft_plan_slice_size / 4, False)
318+
add_to_memory_counters(padded_tmp_p_input_slice, False)
319+
320+
add_to_memory_counters(rfft_result_size, False)
321+
add_to_memory_counters(filtered_rfft_result_size, False)
322+
add_to_memory_counters(-rfft_result_size, False)
323+
add_to_memory_counters(-padded_tmp_p_input_slice, False)
324+
325+
add_to_memory_counters(irfft_scratch_memory_size, False)
326+
add_to_memory_counters(-irfft_scratch_memory_size, False)
327+
add_to_memory_counters(irfft_result_size, False)
328+
add_to_memory_counters(-filtered_rfft_result_size, False)
329+
330+
add_to_memory_counters(-irfft_result_size, False)
331+
else:
332+
add_to_memory_counters(rfft_plan_slice_size / chunk_count * 2, True)
333+
add_to_memory_counters(irfft_plan_slice_size / chunk_count * 2, True)
334+
# add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
335+
for _ in range(0, chunk_count):
336+
add_to_memory_counters(padded_tmp_p_input_slice / chunk_count, True)
312337

313-
add_to_memory_counters(rfft_plan_slice_size / chunk_count * 2, True)
314-
add_to_memory_counters(irfft_plan_slice_size / chunk_count * 2, True)
315-
# add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
316-
for _ in range(0, chunk_count):
317-
add_to_memory_counters(padded_tmp_p_input_slice / chunk_count, True)
318-
319-
add_to_memory_counters(rfft_result_size / chunk_count, True)
320-
add_to_memory_counters(filtered_rfft_result_size / chunk_count, True)
321-
add_to_memory_counters(-rfft_result_size / chunk_count, True)
322-
add_to_memory_counters(-padded_tmp_p_input_slice / chunk_count, True)
338+
add_to_memory_counters(rfft_result_size / chunk_count, True)
339+
add_to_memory_counters(filtered_rfft_result_size / chunk_count, True)
340+
add_to_memory_counters(-rfft_result_size / chunk_count, True)
341+
add_to_memory_counters(-padded_tmp_p_input_slice / chunk_count, True)
323342

324-
add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
325-
add_to_memory_counters(-irfft_scratch_memory_size / chunk_count, True)
326-
add_to_memory_counters(irfft_result_size / chunk_count, True)
327-
add_to_memory_counters(-filtered_rfft_result_size / chunk_count, True)
343+
add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
344+
add_to_memory_counters(-irfft_scratch_memory_size / chunk_count, True)
345+
add_to_memory_counters(irfft_result_size / chunk_count, True)
346+
add_to_memory_counters(-filtered_rfft_result_size / chunk_count, True)
328347

329-
add_to_memory_counters(-irfft_result_size / chunk_count, True)
348+
add_to_memory_counters(-irfft_result_size / chunk_count, True)
330349

331350
add_to_memory_counters(-padded_in_slice_size, True)
332351
add_to_memory_counters(-filter_size, False)
@@ -342,17 +361,25 @@ def add_to_memory_counters(amount, per_slice: bool):
342361

343362
add_to_memory_counters(-fft_result_size, True)
344363

345-
add_to_memory_counters(ifft2_plan_slice_size / chunk_count * 2, True)
346-
for _ in range(0, chunk_count):
347-
add_to_memory_counters(fde_size / chunk_count, True)
348-
add_to_memory_counters(-fde_size / chunk_count, True)
364+
if min_mem_usage_ifft2:
365+
add_to_memory_counters(ifft2_plan_slice_size, False)
366+
add_to_memory_counters(fde_size * 2, False)
367+
add_to_memory_counters(-fde_size * 2, False)
368+
else:
369+
add_to_memory_counters(ifft2_plan_slice_size / chunk_count * 2, True)
370+
for _ in range(0, chunk_count):
371+
add_to_memory_counters(fde_size / chunk_count, True)
372+
add_to_memory_counters(-fde_size / chunk_count, True)
349373

350374
add_to_memory_counters(recon_output_size, True)
351375
add_to_memory_counters(-fde_size, True)
352376
add_to_memory_counters(circular_mask_size, False)
353377
add_to_memory_counters(after_recon_swapaxis_slice, True)
354378

355-
return (tot_memory_bytes * 1.05, fixed_amount + 250 * 1024 * 1024)
379+
if min_mem_usage_ifft2 and min_mem_usage_filter:
380+
return (tot_memory_bytes * 1.1 + 30 * 1024 * 1024, fixed_amount)
381+
else:
382+
return (tot_memory_bytes, fixed_amount)
356383

357384

358385
def _calc_memory_bytes_SIRT3d_tomobar(

tests/test_httomolibgpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
675675
# the estimated_memory_mb should be LARGER or EQUAL to max_mem_mb
676676
# the resulting percent value should not deviate from max_mem on more than 20%
677677
assert estimated_memory_mb >= max_mem_mb
678-
assert percents_relative_maxmem <= 80
678+
assert percents_relative_maxmem <= 85
679679

680680

681681
@pytest.mark.cupy

0 commit comments

Comments
 (0)