Skip to content

Commit 686570e

Browse files
committed
Prepare tests for memory allocation while estimating
1 parent 324cf22 commit 686570e

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tests/test_prep/test_stripe.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,15 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
9696
remove_stripe_fw(cp.copy(data), wname=wname, level=level)
9797
actual_mem_peak = hook.max_mem
9898

99-
hook = MaxMemoryHook()
100-
with hook:
99+
try:
101100
estimated_mem_peak = remove_stripe_fw(
102101
data.shape, level=level, wname=wname, calc_peak_gpu_mem=True
103102
)
103+
except cp.cuda.memory.OutOfMemoryError:
104+
pytest.skip("Not enough GPU memory to estimate memory peak")
104105

105106
assert actual_mem_peak * 0.99 <= estimated_mem_peak
106-
assert estimated_mem_peak <= actual_mem_peak * 1.3
107+
assert estimated_mem_peak <= actual_mem_peak * 1.15
107108

108109

109110
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
@@ -115,11 +116,12 @@ def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory
115116
dim_y = 901
116117
dim_x = 1200
117118
data_shape = (slices, dim_x, dim_y)
118-
hook = MaxMemoryHook()
119-
with hook:
119+
try:
120120
estimated_mem_peak = remove_stripe_fw(
121121
data_shape, wname=wname, level=level, calc_peak_gpu_mem=True
122122
)
123+
except cp.cuda.memory.OutOfMemoryError:
124+
pytest.skip("Not enough GPU memory to estimate memory peak")
123125
av_mem = cp.cuda.Device().mem_info[0]
124126
if av_mem < estimated_mem_peak:
125127
pytest.skip("Not enough GPU memory to run this test")
@@ -131,7 +133,7 @@ def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory
131133
actual_mem_peak = hook.max_mem
132134

133135
assert actual_mem_peak * 0.99 <= estimated_mem_peak
134-
assert estimated_mem_peak <= actual_mem_peak * 1.3
136+
assert estimated_mem_peak <= actual_mem_peak * 1.15
135137

136138

137139
@pytest.mark.parametrize("angles", [180, 181])

0 commit comments

Comments
 (0)