Skip to content

Commit d5a6b1a

Browse files
Add implementation of and signature for State.getDevice
make batch/sync arguments of State.exec keyword-only Provide default column_name value for State.addElementCount method, so that it can be called state.addElementCount(count), or as state.addElementCount(count, column_name="Descriptive Name")
1 parent 3f1feb2 commit d5a6b1a

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

python/cuda/nvbench/__init__.pyi

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# from __future__ import annotations
22

33
from collections.abc import Callable, Sequence
4-
from typing import Optional, Self
4+
from typing import Optional, Self, Union
55

66
class CudaStream:
77
"""Represents CUDA stream
@@ -15,6 +15,18 @@ class CudaStream:
1515
Special method implement CUDA stream protocol
1616
from `cuda.core`. Returns a pair of integers:
1717
(protocol_version, integral_value_of_cudaStream_t pointer)
18+
19+
Example
20+
-------
21+
import cuda.core.experimental as core
22+
import cuda.nvbench as nvbench
23+
24+
def bench(state: nvbench.State):
25+
dev = core.Device(state.getDevice())
26+
dev.set_current()
27+
# converts CudaString to core.Stream
28+
# using __cuda_stream__ protocol
29+
dev.create_stream(state.getStream())
1830
"""
1931
...
2032

@@ -68,6 +80,9 @@ class State:
6880
def hasPrinters(self) -> bool:
6981
"True if configuration has a printer"
7082
...
83+
def getDevice(self) -> Union[int, None]:
84+
"Get device_id of the device from this configuration"
85+
...
7186
def getStream(self) -> CudaStream:
7287
"CudaStream object from this configuration"
7388
...
@@ -150,6 +165,8 @@ class State:
150165
def exec(
151166
self,
152167
fn: Callable[[Launch], None],
168+
/,
169+
*,
153170
batched: Optional[bool] = True,
154171
sync: Optional[bool] = False,
155172
):

python/src/py_nvbench.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,14 @@ PYBIND11_MODULE(_nvbench, m)
344344
pystate_cls.def("hasPrinters", [](nvbench::state &state) -> bool {
345345
return state.get_benchmark().get_printer().has_value();
346346
});
347+
pystate_cls.def("getDevice", [](nvbench::state &state) {
348+
auto dev = state.get_device();
349+
if (dev.has_value())
350+
{
351+
return py::cast(dev.value().get_id());
352+
}
353+
return py::object(py::none());
354+
});
347355

348356
pystate_cls.def(
349357
"getStream",
@@ -359,7 +367,10 @@ PYBIND11_MODULE(_nvbench, m)
359367
pystate_cls.def("getString", &nvbench::state::get_string);
360368
pystate_cls.def("getString", &nvbench::state::get_string_or_default);
361369

362-
pystate_cls.def("addElementCount", &nvbench::state::add_element_count);
370+
pystate_cls.def("addElementCount",
371+
&nvbench::state::add_element_count,
372+
py::arg("count"),
373+
py::arg("column_name") = py::str(""));
363374
pystate_cls.def("setElementCount", &nvbench::state::set_element_count);
364375
pystate_cls.def("getElementCount", &nvbench::state::get_element_count);
365376

0 commit comments

Comments
 (0)