Skip to content

Commit d99df23

Browse files
committed
Attempt to repair device memory fragmentation
1 parent 686570e commit d99df23

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,12 @@ def apply(
615615
return yl
616616

617617

618+
def _repair_memory_fragmentation_if_needed(fragmentation_threshold: float = 0.2) -> None:
    """Free the default memory pool's cached blocks when too much of it is idle.

    Compares the bytes the pool holds but is not using (``total_bytes() -
    used_bytes()``) with the bytes currently in use. When the idle part exceeds
    ``fragmentation_threshold`` times the in-use part, ``free_all_blocks()`` is
    called on the pool.

    Args:
        fragmentation_threshold: Largest tolerated value of
            ``(total - used) / used`` before the pool's blocks are freed.
    """
    pool = cp.get_default_memory_pool()
    total = pool.total_bytes()
    used = pool.used_bytes()
    # Cross-multiplied form of ``(total / used) - 1 > threshold``. It gives the
    # same answer whenever ``used > 0`` but cannot raise ZeroDivisionError when
    # nothing is in use; in that case any held memory is idle and is released.
    if total - used > fragmentation_threshold * used:
        pool.free_all_blocks()
623+
618624
def remove_stripe_fw(
619625
data: cp.ndarray,
620626
sigma: float = 1,
@@ -710,7 +716,7 @@ def remove_stripe_fw(
710716
for c in cc:
711717
mem_stack.free(np.prod(c) * np.float32().itemsize)
712718
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
713-
return mem_stack.highwater
719+
return int(mem_stack.highwater * 1.1)
714720

715721
sli = cp.zeros(sli_shape, dtype="float32")
716722
sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2] = data.swapaxes(0, 1)
@@ -736,12 +742,14 @@ def remove_stripe_fw(
736742
cc[k][:, 0, 1] = cp.fft.ifft(ifft_in, my, axis=1).real
737743
del ifft_plan
738744
del ifft_in
745+
_repair_memory_fragmentation_if_needed()
739746

740747
# Wavelet reconstruction.
741748
for k in range(level)[::-1]:
742749
shape0 = cc[k][0, 0, 1].shape
743750
sli = sli[:, :, : shape0[0], : shape0[1]]
744751
sli = ifm.apply((sli, cc[k]))
752+
_repair_memory_fragmentation_if_needed()
745753

746754
data = sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2, :ni]
747755
data = data.swapaxes(0, 1)

tests/test_prep/test_stripe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
109109

110110
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
111111
@pytest.mark.parametrize(
112-
"slices", [177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105]
112+
"slices", [38, 177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105]
113113
)
114114
@pytest.mark.parametrize("level", [None, 7, 11])
115-
def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory):
116-
dim_y = 901
117-
dim_x = 1200
115+
@pytest.mark.parametrize("dims", [(901, 1200), (1801, 2560)])
116+
def test_remove_stripe_fw_calc_mem_big(wname, slices, level, dims, ensure_clean_memory):
117+
dim_y, dim_x = dims
118118
data_shape = (slices, dim_x, dim_y)
119119
try:
120120
estimated_mem_peak = remove_stripe_fw(

0 commit comments

Comments
 (0)