Skip to content

Commit 5f52fc5

Browse files
Add examples/skip.py
1 parent bdf9fda commit 5f52fc5

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

python/examples/skip.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import sys
2+
3+
import cuda.cccl.headers as headers
4+
import cuda.core.experimental as core
5+
import cuda.nvbench as nvbench
6+
7+
8+
def make_sleep_kernel():
    """Compile the CUDA C++ ``sleep_kernel(double seconds)`` and return it.

    The kernel source is compiled to a cubin with the libcudacxx include
    directory on the include path, and the ``sleep_kernel`` entry point is
    looked up in the resulting module.
    """
    # Device code: spin on the high-resolution clock until `seconds` elapsed.
    kernel_source = r"""
#include <cuda/std/cstdint>
#include <cuda/std/chrono>

// Each launched thread just sleeps for `seconds`.
__global__ void sleep_kernel(double seconds) {
  namespace chrono = ::cuda::std::chrono;
  using hr_clock = chrono::high_resolution_clock;

  auto duration = static_cast<cuda::std::int64_t>(seconds * 1e9);
  const auto ns = chrono::nanoseconds(duration);

  const auto start = hr_clock::now();
  const auto finish = start + ns;

  auto now = hr_clock::now();
  while (now < finish)
  {
    now = hr_clock::now();
  }
}
"""
    # The <cuda/std/...> headers live under the libcudacxx include directory.
    include_dirs = headers.get_include_paths()
    libcudacxx_dir = str(include_dirs.libcudacxx)
    options = core.ProgramOptions(include_path=libcudacxx_dir)

    program = core.Program(kernel_source, code_type="c++", options=options)
    module = program.compile("cubin", name_expressions=("sleep_kernel",))
    return module.get_kernel("sleep_kernel")
37+
38+
39+
def runtime_skip(state: nvbench.State):
    """Benchmark body that skips some axis combinations at run time.

    Reads the "Duration" (float64, seconds) and "Kramble" (string) axis
    values from ``state``. Two combinations are skipped via ``state.skip``:

    * "Baz" with a duration below 0.8 ms
    * "Foo" with a duration above 0.3 ms

    Every other combination launches the sleep kernel with a single thread
    for ``Duration`` seconds.
    """
    seconds = state.getFloat64("Duration")
    label = state.getString("Kramble")

    # Decide whether this configuration should be skipped, and why.
    skip_reason = None
    if label == "Baz" and seconds < 0.8e-3:
        # "Baz" runs shorter than 0.8 ms are not measured.
        skip_reason = "Short 'Baz' benchmarks are skipped"
    elif label == "Foo" and seconds > 0.3e-3:
        # "Foo" runs longer than 0.3 ms are not measured.
        skip_reason = "Long 'Foo' benchmarks are skipped"

    if skip_reason is not None:
        state.skip(skip_reason)
        return

    sleep_kernel = make_sleep_kernel()
    # One block containing one thread, no dynamic shared memory.
    config = core.LaunchConfig(grid=1, block=1, shmem_size=0)

    def launcher(launch: nvbench.Launch):
        device = core.Device()
        device.set_current()

        # Wrap the stream handed over by nvbench so the launch goes onto it.
        stream = device.create_stream(launch.getStream())
        core.launch(stream, config, sleep_kernel, seconds)

    state.exec(launcher)
64+
65+
66+
if __name__ == "__main__":
    # Five durations starting at 0.1 ms, spaced 0.25 ms apart (in seconds).
    durations = [1e-4 + k * 0.25e-3 for k in range(5)]
    labels = ["Foo", "Bar", "Baz"]

    # Register the benchmark and attach both axes to it.
    bench = nvbench.register(runtime_skip)
    bench.addFloat64Axis("Duration", durations)
    bench.addStringAxis("Kramble", labels)

    nvbench.run_all_benchmarks(sys.argv)

0 commit comments

Comments
 (0)