Skip to content

Commit ad4d7bd

Browse files
committed
Update mem estimator and tests for very large sizes
1 parent beb1289 commit ad4d7bd

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -337,8 +337,7 @@ def _conv_transpose2d(
337337
out_shape = [b, ci, hi, wi]
338338
if mem_stack:
339339
# The trouble here is that we allocate more than the returned size
340-
out_actual_bytes = np.prod(out_shape) * np.float32().itemsize
341-
mem_stack.malloc(out_actual_bytes)
340+
mem_stack.malloc((np.prod(out_shape) + w.size) * np.float32().itemsize)
342341
if pad != 0:
343342
new_out_shape = [
344343
out_shape[0],
@@ -347,7 +346,7 @@ def _conv_transpose2d(
347346
out_shape[3] - 2 * pad[1],
348347
]
349348
mem_stack.malloc(np.prod(new_out_shape) * np.float32().itemsize)
350-
mem_stack.free(np.prod(out_shape) * np.float32().itemsize)
349+
mem_stack.free((np.prod(out_shape) + w.size) * np.float32().itemsize)
351350
out_shape = new_out_shape
352351
return out_shape
353352

@@ -673,7 +672,7 @@ def remove_stripe_fw(
673672
mem_stack.free(np.prod(c) * np.float32().itemsize)
674673
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
675674
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
676-
return mem_stack.highwater
675+
return int(mem_stack.highwater * 1.1)
677676

678677
sli = cp.zeros(sli_shape, dtype="float32")
679678
sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2] = data.swapaxes(0, 1)

tests/test_prep/test_stripe.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -85,8 +85,8 @@ def ensure_clean_memory():
8585

8686

8787
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
88-
@pytest.mark.parametrize("slices", [55, 80])
89-
@pytest.mark.parametrize("level", [None, 1, 3, 7, 11])
88+
@pytest.mark.parametrize("slices", [3, 7, 32, 61, 109, 120, 150])
89+
@pytest.mark.parametrize("level", [None, 1, 3, 11])
9090
@pytest.mark.parametrize("dim_x", [128, 140])
9191
def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_memory):
9292
dim_y = 159
@@ -104,31 +104,32 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
104104
assert hook.max_mem == 0
105105

106106
assert actual_mem_peak * 0.99 <= estimated_mem_peak
107-
assert estimated_mem_peak <= actual_mem_peak * 1.2
107+
assert estimated_mem_peak <= actual_mem_peak * 1.3
108108

109109

110-
@pytest.mark.parametrize("wname", ['db4', 'sym16'])
110+
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
111111
@pytest.mark.parametrize("slices", [177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105])
112-
def test_remove_stripe_fw_calc_mem_big(wname, slices, ensure_clean_memory):
112+
@pytest.mark.parametrize("level", [None, 7, 11])
113+
def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory):
113114
dim_y = 901
114115
dim_x = 1200
115116
data_shape = (slices, dim_x, dim_y)
116117
hook = MaxMemoryHook()
117118
with hook:
118-
estimated_mem_peak = remove_stripe_fw(data_shape, wname=wname, calc_peak_gpu_mem=True)
119+
estimated_mem_peak = remove_stripe_fw(data_shape, wname=wname, level=level, calc_peak_gpu_mem=True)
119120
assert hook.max_mem == 0
120121
av_mem = cp.cuda.Device().mem_info[0]
121-
if av_mem < estimated_mem_peak * 1.1:
122+
if av_mem < estimated_mem_peak:
122123
pytest.skip("Not enough GPU memory to run this test")
123124

124125
hook = MaxMemoryHook()
125126
with hook:
126127
data = cp.random.random_sample(data_shape, dtype=np.float32)
127-
remove_stripe_fw(data, wname=wname)
128+
remove_stripe_fw(data, wname=wname, level=level)
128129
actual_mem_peak = hook.max_mem
129130

130131
assert actual_mem_peak * 0.99 <= estimated_mem_peak
131-
assert estimated_mem_peak <= actual_mem_peak * 1.2
132+
assert estimated_mem_peak <= actual_mem_peak * 1.3
132133

133134

134135
@pytest.mark.parametrize("angles", [180, 181])

0 commit comments

Comments
 (0)