Skip to content

Commit f945737

Browse files
committed
WIP: try using cuda streams
1 parent e1a9d8f commit f945737

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,23 @@ def remove_all_stripe(
201201
Corrected 3D tomographic data as a CuPy or NumPy array.
202202
203203
"""
204-
for m in range(data.shape[1]):
205-
data[:, m, :] = _rs_dead(data[:, m, :], snr, la_size)
206-
data[:, m, :] = _rs_sort(data[:, m, :], sm_size, dim)
207-
data[:, m, :] = cp.nan_to_num(data[:, m, :])
208-
209-
return data
204+
streams = [cp.cuda.Stream() for _ in range(4)]
205+
output = data.copy()
206+
def process_slice(m, stream):
207+
with stream:
208+
output[:, m, :] = _rs_dead(output[:, m, :], snr, la_size)
209+
output[:, m, :] = _rs_sort(output[:, m, :], sm_size, dim)
210+
output[:, m, :] = cp.nan_to_num(output[:, m, :])
211+
212+
# Distribute slices across streams
213+
for i in range(data.shape[1]):
214+
stream = streams[i % 4]
215+
process_slice(i, stream)
216+
217+
for stream in streams:
218+
stream.synchronize()
219+
220+
return output
210221

211222

212223
def _mpolyfit(x, y):

remove_all_stripe.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import cupy as cp
2+
import numpy as np
3+
import os
4+
import time
5+
from cupy.cuda import memory_hooks
6+
from datetime import datetime
7+
from math import isclose
8+
from cupyx.profiler import time_range
9+
10+
from httomolibgpu.prep.stripe import remove_all_stripe
11+
12+
# Benchmark and regression driver for ``remove_all_stripe``.
#
# Steps, in order:
#   1. load the projection data from an ``.npz`` archive and move it to the GPU;
#   2. run the filter once inside a profiler range, then once more untimed;
#   3. time ``N_TIMED_RUNS`` back-to-back invocations and print the mean in ms;
#   4. compare the residual norm against recorded reference values;
#   5. run once under CuPy's line-level memory profiler and print its report.

test_data_path = "/mnt/gpfs03/scratch/data/imaging/tomography/zenodo"
data_path = os.path.join(test_data_path, "synth_tomophantom1.npz")

# ``np.load`` on an .npz archive returns a lazily-reading NpzFile that keeps
# the underlying file open; the context manager closes it once the arrays
# have been copied to the device.
with np.load(data_path) as data_file:
    projdata = cp.asarray(cp.swapaxes(data_file["projdata"], 0, 1))
    # Loaded for parity with the dataset contents; not used further below.
    angles = cp.asarray(data_file["angles"])

# Keyword arguments shared by the profiling, timing and memory-hook runs.
BENCH_PARAMS = dict(snr=0.1, la_size=71, sm_size=31, dim=1)

# Number of invocations averaged in the timed loop. Used both as the loop
# bound and as the divisor so the two cannot drift apart.
N_TIMED_RUNS = 10

# (snr, expected residual L2 norm) pairs for the regression checks, which
# use la_size=61 / sm_size=21 rather than the benchmark parameters.
REGRESSION_CASES = (
    (0.1, 67917.71),
    (0.001, 70015.51),
)

# First invocation, wrapped in a named profiler range. Being the very first
# call, it also carries any one-off start-up cost (presumably kernel
# compilation -- confirm in the profiler timeline).
with time_range("all_stripe", color_id=0):
    remove_all_stripe(cp.copy(projdata), **BENCH_PARAMS)

# Second, untimed invocation before the timed loop starts.
remove_all_stripe(cp.copy(projdata), **BENCH_PARAMS)

dev = cp.cuda.Device()
# Drain outstanding GPU work so it is not attributed to the timed section.
dev.synchronize()
start = time.perf_counter_ns()
for _ in range(N_TIMED_RUNS):
    # NOTE: the input copy happens inside the loop, so its cost is part of
    # the reported per-call figure.
    remove_all_stripe(cp.copy(projdata), **BENCH_PARAMS)

# Wait for the device to finish before reading the clock.
dev.synchronize()
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / N_TIMED_RUNS

print(duration_ms)

# Regression checks: the L2 norm of (input - output) must match the recorded
# reference value to within 1e-2.
for snr, expected_norm in REGRESSION_CASES:
    output = remove_all_stripe(
        cp.copy(projdata), snr=snr, la_size=61, sm_size=21, dim=1
    )
    residual_calc = projdata - output
    # Explicit device-to-host conversion; ``math.isclose`` works on floats.
    norm_res = float(cp.linalg.norm(residual_calc.flatten()))
    assert isclose(norm_res, expected_norm, abs_tol=10**-2)

# Per-line GPU memory allocation report for a single invocation.
hook = memory_hooks.LineProfileHook()
with hook:
    remove_all_stripe(cp.copy(projdata), **BENCH_PARAMS)
hook.print_report()

0 commit comments

Comments
 (0)