Skip to content

Commit 37428fb

Browse files
Add Python stub file for cuda.nvbench API
1 parent c07a84f commit 37428fb

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed

python/cuda/nvbench/__init__.pyi

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from typing import Callable, Optional, Sequence, Tuple
2+
3+
class CudaStream:
    """Represents a CUDA stream.

    Note
    ----
    The class is not directly constructible.
    """

    def __cuda_stream__(self) -> Tuple[int, int]:
        """
        Special method implementing the CUDA stream protocol
        from `cuda.core`. Returns a pair of integers:
        (protocol_version, integral_value_of_cudaStream_t pointer)
        """
        ...

    def addressof(self) -> int:
        """Integral value of address of driver's CUDA stream struct."""
        ...
21+
22+
class Benchmark:
    """Represents NVBench benchmark.

    Note
    ----
    The class is not user-constructible.
    Use `~register` function to create Benchmark and register
    it with NVBench.
    """

    def getName(self) -> str:
        """Get benchmark name."""
        ...

    def addInt64Axis(self, name: str, values: Sequence[int]) -> Benchmark:
        """Add integral type parameter axis with given name and values to sweep over."""
        ...

    def addFloat64Axis(self, name: str, values: Sequence[float]) -> Benchmark:
        """Add floating-point type parameter axis with given name and values to sweep over."""
        ...

    # First parameter was misspelled `sef`; renamed to `self` to match the
    # other methods of this class.
    def addStringAxis(self, name: str, values: Sequence[str]) -> Benchmark:
        """Add string type parameter axis with given name and values to sweep over."""
        ...
43+
44+
class Launch:
    """Configuration object describing a function launch.

    Note
    ----
    Instances cannot be constructed by the user.
    """

    def getStream(self) -> CudaStream:
        """Return the CUDA stream associated with this launch configuration."""
        ...
54+
55+
class State:
    """Represent benchmark configuration state.

    Note
    ----
    The class is not user-constructible.
    """

    def hasDevice(self) -> bool:
        """True if configuration has a device."""
        ...

    def hasPrinters(self) -> bool:
        """True if configuration has a printer."""
        ...

    def getStream(self) -> CudaStream:
        """CudaStream object from this configuration."""
        ...

    # `default_value` defaults to None, so the annotation is spelled as an
    # explicit Optional (implicit Optional is rejected by type checkers).
    def getInt64(self, name: str, default_value: Optional[int] = None) -> int:
        """Get value for given Int64 axis from this configuration."""
        ...

    def getFloat64(self, name: str, default_value: Optional[float] = None) -> float:
        """Get value for given Float64 axis from this configuration."""
        ...

    def getString(self, name: str, default_value: Optional[str] = None) -> str:
        """Get value for given String axis from this configuration."""
        ...

    def addElementCount(self, count: int, column_name: Optional[str] = None) -> None:
        """Add element count."""
        ...

    def setElementCount(self, count: int) -> None:
        """Set element count."""
        ...

    def getElementCount(self) -> int:
        """Get element count."""
        ...

    def skip(self, reason: str) -> None:
        """Skip this configuration."""
        ...

    def isSkipped(self) -> bool:
        """Has this configuration been skipped."""
        ...

    def getSkipReason(self) -> str:
        """Get reason provided for skipping this configuration."""
        ...

    def addGlobalMemoryReads(self, nbytes: int) -> None:
        """Inform NVBench that given amount of bytes is being read by the benchmark from global memory."""
        ...

    def addGlobalMemoryWrites(self, nbytes: int) -> None:
        """Inform NVBench that given amount of bytes is being written by the benchmark into global memory."""
        ...

    def getBenchmark(self) -> Benchmark:
        """Get Benchmark this configuration is a part of."""
        ...

    def getThrottleThreshold(self) -> float:
        """Get throttle threshold value."""
        ...

    def getMinSamples(self) -> int:
        """Get the number of benchmark timings NVBench performs before stopping criterion begins being used."""
        ...

    def setMinSamples(self, count: int) -> None:
        """Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used."""
        ...

    def getDisableBlockingKernel(self) -> bool:
        """True if use of blocking kernel by NVBench is disabled, False otherwise."""
        ...

    def setDisableBlockingKernel(self, flag: bool) -> None:
        """Use flag = True to disable use of blocking kernel by NVBench."""
        ...

    def getRunOnce(self) -> bool:
        """Boolean flag whether configuration should only run once."""
        ...

    def setRunOnce(self, flag: bool) -> None:
        """Set run-once flag for this configuration."""
        ...

    def getTimeout(self) -> float:
        """Get time-out value for benchmark execution of this configuration."""
        ...

    def setTimeout(self, duration: float) -> None:
        """Set time-out value for benchmark execution of this configuration."""
        ...

    def getBlockingKernelTimeout(self) -> float:
        """Get time-out value for execution of blocking kernel."""
        ...

    def setBlockingKernelTimeout(self, duration: float) -> None:
        """Set time-out value for execution of blocking kernel."""
        ...

    def collectCUPTIMetrics(self) -> None:
        """Request NVBench to record CUPTI metrics while running benchmark for this configuration."""
        ...

    def isCUPTIRequired(self) -> bool:
        """True if (some) CUPTI metrics are being collected."""
        ...

    # NOTE(review): the return annotation was missing; `-> None` assumes the
    # binding returns nothing -- confirm against the extension module.
    def exec(
        self, fn: Callable[[Launch], None], batched: bool = True, sync: bool = False
    ) -> None:
        """Execute callable running the benchmark.

        The callable may be executed multiple times.

        Parameters
        ----------
        fn: Callable
            Python callable with signature fn(Launch) -> None that executes the benchmark.
        batched: bool, optional
            If `True`, no cache flushing is performed between callable invocations.
            Default: `True`.
        sync: bool, optional
            True value indicates that callable performs device synchronization.
            NVBench disables use of blocking kernel in this case.
            Default: `False`.
        """
        ...
167+
168+
def register(fn: Callable[[State], None]) -> Benchmark:
    """Register a benchmarking function with NVBench."""
    ...
173+
174+
def run_all_benchmarks(argv: Sequence[str]) -> None:
    """Run every benchmark that has been registered with NVBench.

    Parameters
    ----------
    argv: Sequence[str]
        Sequence of CLI arguments controlling NVBench. Usually, it is `sys.argv`.
    """
    ...

0 commit comments

Comments
 (0)