Skip to content

Commit b19cde3

Browse files
Avoid overloading get_int64_or_default as get_int64
Introduce get_int64_or_default method, and counterparts for float64 and string. Provided names for Python arguments. Tried generating Python stubs automatically with ``` stubgen -m cuda.nvbench._nvbench ``` Gave up on this, since it does not include doc-strings. It would be nice to compare auto-generated _nvbench.pyi with __init__.pyi for discrepancies though.
1 parent 1385565 commit b19cde3

File tree

2 files changed

+80
-42
lines changed

2 files changed

+80
-42
lines changed

python/cuda/nvbench/__init__.pyi

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,22 @@ class State:
9898
def get_stream(self) -> CudaStream:
9999
"CudaStream object from this configuration"
100100
...
101-
def get_int64(self, name: str, default_value: Optional[int] = None) -> int:
101+
def get_int64(self, name: str) -> int:
102102
"Get value for given Int64 axis from this configuration"
103103
...
104-
def get_float64(self, name: str, default_value: Optional[float] = None) -> float:
104+
def get_int64_or_default_value(self, name: str, default_value: int) -> int:
105+
"Get value for given Int64 axis from this configuration"
106+
...
107+
def get_float64(self, name: str) -> float:
108+
"Get value for given Float64 axis from this configuration"
109+
...
110+
def get_float64_or_default_value(self, name: str, default_value: float) -> float:
105111
"Get value for given Float64 axis from this configuration"
106112
...
107-
def get_string(self, name: str, default_value: Optional[str] = None) -> str:
113+
def get_string(self, name: str) -> str:
114+
"Get value for given String axis from this configuration"
115+
...
116+
def get_string_or_default_value(self, name: str, default_value: str) -> str:
108117
"Get value for given String axis from this configuration"
109118
...
110119
def add_element_count(self, count: int, column_name: Optional[str] = None) -> None:
@@ -140,7 +149,7 @@ class State:
140149
def get_min_samples(self) -> int:
141150
"Get the number of benchmark timings NVBench performs before stopping criterion begins being used"
142151
...
143-
def set_min_samples(self, count: int) -> None:
152+
def set_min_samples(self, min_samples_count: int) -> None:
144153
"Set the number of benchmark timings for NVBench to perform before stopping criterion begins being used"
145154
...
146155
def get_disable_blocking_kernel(self) -> bool:
@@ -152,20 +161,20 @@ class State:
152161
def get_run_once(self) -> bool:
153162
"Boolean flag whether configuration should only run once"
154163
...
155-
def set_run_once(self, flag: bool) -> None:
164+
def set_run_once(self, run_once_flag: bool) -> None:
156165
"Set run-once flag for this configuration"
157166
...
158167
def get_timeout(self) -> float:
159168
"Get time-out value for benchmark execution of this configuration"
160169
...
161170
def set_timeout(self, duration: float) -> None:
162-
"Set time-out value for benchmark execution of this configuration"
171+
"Set time-out value for benchmark execution of this configuration, in seconds"
163172
...
164173
def get_blocking_kernel_timeout(self) -> float:
165174
"Get time-out value for execution of blocking kernel"
166175
...
167176
def set_blocking_kernel_timeout(self, duration: float) -> None:
168-
"Set time-out value for execution of blocking kernel"
177+
"Set time-out value for execution of blocking kernel, in seconds"
169178
...
170179
def collect_cupti_metrics(self) -> None:
171180
"Request NVBench to record CUPTI metrics while running benchmark for this configuration"

python/src/py_nvbench.cpp

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -405,14 +405,26 @@ PYBIND11_MODULE(_nvbench, m)
405405
[](nvbench::state &state) { return std::ref(state.get_cuda_stream()); },
406406
py::return_value_policy::reference);
407407

408-
pystate_cls.def("get_int64", &nvbench::state::get_int64);
409-
pystate_cls.def("get_int64", &nvbench::state::get_int64_or_default);
410-
411-
pystate_cls.def("get_float64", &nvbench::state::get_float64);
412-
pystate_cls.def("get_float64", &nvbench::state::get_float64_or_default);
413-
414-
pystate_cls.def("get_string", &nvbench::state::get_string);
415-
pystate_cls.def("get_string", &nvbench::state::get_string_or_default);
408+
pystate_cls.def("get_int64", &nvbench::state::get_int64, py::arg("name"));
409+
pystate_cls.def("get_int64_or_default",
410+
&nvbench::state::get_int64_or_default,
411+
py::arg("name"),
412+
py::pos_only{},
413+
py::arg("default_value"));
414+
415+
pystate_cls.def("get_float64", &nvbench::state::get_float64, py::arg("name"));
416+
pystate_cls.def("get_float64_or_default",
417+
&nvbench::state::get_float64_or_default,
418+
py::arg("name"),
419+
py::pos_only{},
420+
py::arg("default_value"));
421+
422+
pystate_cls.def("get_string", &nvbench::state::get_string, py::arg("name"));
423+
pystate_cls.def("get_string_or_default",
424+
&nvbench::state::get_string_or_default,
425+
py::arg("name"),
426+
py::pos_only{},
427+
py::arg("default_value"));
416428

417429
pystate_cls.def("add_element_count",
418430
&nvbench::state::add_element_count,
@@ -421,7 +433,7 @@ PYBIND11_MODULE(_nvbench, m)
421433
pystate_cls.def("set_element_count", &nvbench::state::set_element_count);
422434
pystate_cls.def("get_element_count", &nvbench::state::get_element_count);
423435

424-
pystate_cls.def("skip", &nvbench::state::skip);
436+
pystate_cls.def("skip", &nvbench::state::skip, py::arg("reason"));
425437
pystate_cls.def("is_skipped", &nvbench::state::is_skipped);
426438
pystate_cls.def("get_skip_reason", &nvbench::state::get_skip_reason);
427439

@@ -450,19 +462,25 @@ PYBIND11_MODULE(_nvbench, m)
450462
pystate_cls.def("get_throttle_threshold", &nvbench::state::get_throttle_threshold);
451463

452464
pystate_cls.def("get_min_samples", &nvbench::state::get_min_samples);
453-
pystate_cls.def("set_min_samples", &nvbench::state::set_min_samples);
465+
pystate_cls.def("set_min_samples",
466+
&nvbench::state::set_min_samples,
467+
py::arg("min_samples_count"));
454468

455469
pystate_cls.def("get_disable_blocking_kernel", &nvbench::state::get_disable_blocking_kernel);
456-
pystate_cls.def("set_disable_blocking_kernel", &nvbench::state::set_disable_blocking_kernel);
470+
pystate_cls.def("set_disable_blocking_kernel",
471+
&nvbench::state::set_disable_blocking_kernel,
472+
py::arg("disable_blocking_kernel"));
457473

458474
pystate_cls.def("get_run_once", &nvbench::state::get_run_once);
459-
pystate_cls.def("set_run_once", &nvbench::state::set_run_once);
475+
pystate_cls.def("set_run_once", &nvbench::state::set_run_once, py::arg("run_once"));
460476

461477
pystate_cls.def("get_timeout", &nvbench::state::get_timeout);
462-
pystate_cls.def("set_timeout", &nvbench::state::set_timeout);
478+
pystate_cls.def("set_timeout", &nvbench::state::set_timeout, py::arg("duration"));
463479

464480
pystate_cls.def("get_blocking_kernel_timeout", &nvbench::state::get_blocking_kernel_timeout);
465-
pystate_cls.def("set_blocking_kernel_timeout", &nvbench::state::set_blocking_kernel_timeout);
481+
pystate_cls.def("set_blocking_kernel_timeout",
482+
&nvbench::state::set_blocking_kernel_timeout,
483+
py::arg("duration"));
466484

467485
pystate_cls.def("collect_cupti_metrics", &nvbench::state::collect_cupti_metrics);
468486
pystate_cls.def("is_cupti_required", &nvbench::state::is_cupti_required);
@@ -510,26 +528,36 @@ PYBIND11_MODULE(_nvbench, m)
510528
pystate_cls.def("get_short_description",
511529
[](const nvbench::state &state) { return state.get_short_description(); });
512530

513-
pystate_cls.def("add_summary",
514-
[](nvbench::state &state, std::string column_name, std::string value) {
515-
auto &summ = state.add_summary("nv/python/" + column_name);
516-
summ.set_string("description", "User tag: " + column_name);
517-
summ.set_string("name", std::move(column_name));
518-
summ.set_string("value", std::move(value));
519-
});
520-
pystate_cls.def("add_summary",
521-
[](nvbench::state &state, std::string column_name, std::int64_t value) {
522-
auto &summ = state.add_summary("nv/python/" + column_name);
523-
summ.set_string("description", "User tag: " + column_name);
524-
summ.set_string("name", std::move(column_name));
525-
summ.set_int64("value", value);
526-
});
527-
pystate_cls.def("add_summary", [](nvbench::state &state, std::string column_name, double value) {
528-
auto &summ = state.add_summary("nv/python/" + column_name);
529-
summ.set_string("description", "User tag: " + column_name);
530-
summ.set_string("name", std::move(column_name));
531-
summ.set_float64("value", value);
532-
});
531+
pystate_cls.def(
532+
"add_summary",
533+
[](nvbench::state &state, std::string column_name, std::string value) {
534+
auto &summ = state.add_summary("nv/python/" + column_name);
535+
summ.set_string("description", "User tag: " + column_name);
536+
summ.set_string("name", std::move(column_name));
537+
summ.set_string("value", std::move(value));
538+
},
539+
py::arg("column_name"),
540+
py::arg("value"));
541+
pystate_cls.def(
542+
"add_summary",
543+
[](nvbench::state &state, std::string column_name, std::int64_t value) {
544+
auto &summ = state.add_summary("nv/python/" + column_name);
545+
summ.set_string("description", "User tag: " + column_name);
546+
summ.set_string("name", std::move(column_name));
547+
summ.set_int64("value", value);
548+
},
549+
py::arg("name"),
550+
py::arg("value"));
551+
pystate_cls.def(
552+
"add_summary",
553+
[](nvbench::state &state, std::string column_name, double value) {
554+
auto &summ = state.add_summary("nv/python/" + column_name);
555+
summ.set_string("description", "User tag: " + column_name);
556+
summ.set_string("name", std::move(column_name));
557+
summ.set_float64("value", value);
558+
},
559+
py::arg("name"),
560+
py::arg("value"));
533561

534562
// Use handle to take a memory leak here, since this object's destructor may be called after
535563
// interpreter has shut down
@@ -546,7 +574,8 @@ PYBIND11_MODULE(_nvbench, m)
546574
"register",
547575
[&](py::object fn) { return std::ref(global_registry->add_bench(fn)); },
548576
"Register benchmark function of type Callable[[nvbench.State], None]",
549-
py::return_value_policy::reference);
577+
py::return_value_policy::reference,
578+
py::arg("benchmark_fn"));
550579

551580
m.def(
552581
"run_all_benchmarks",

0 commit comments

Comments
 (0)