Skip to content

Commit 428ddee

Browse files
nvbench.State.exec validates arg to be a callable
Add names to method arguments to make it more self-descriptive.
1 parent e709063 commit 428ddee

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

python/src/py_nvbench.cpp

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ PYBIND11_MODULE(_nvbench, m)
259259
self.add_int64_axis(std::move(name), std::move(data));
260260
return std::ref(self);
261261
},
262-
py::return_value_policy::reference);
262+
py::return_value_policy::reference,
263+
py::arg("name"),
264+
py::arg("values"));
263265
py_benchmark_cls.def(
264266
"add_int64_power_of_two_axis",
265267
[](nvbench::benchmark_base &self, std::string name, std::vector<nvbench::int64_t> data) {
@@ -268,42 +270,51 @@ PYBIND11_MODULE(_nvbench, m)
268270
nvbench::int64_axis_flags::power_of_two);
269271
return std::ref(self);
270272
},
271-
py::return_value_policy::reference);
273+
py::return_value_policy::reference,
274+
py::arg("name"),
275+
py::arg("values"));
272276
py_benchmark_cls.def(
273277
"add_float64_axis",
274278
[](nvbench::benchmark_base &self, std::string name, std::vector<nvbench::float64_t> data) {
275279
self.add_float64_axis(std::move(name), std::move(data));
276280
return std::ref(self);
277281
},
278-
py::return_value_policy::reference);
282+
py::return_value_policy::reference,
283+
py::arg("name"),
284+
py::arg("values"));
279285
py_benchmark_cls.def(
280286
"add_string_axis",
281287
[](nvbench::benchmark_base &self, std::string name, std::vector<std::string> data) {
282288
self.add_string_axis(std::move(name), std::move(data));
283289
return std::ref(self);
284290
},
285-
py::return_value_policy::reference);
291+
py::return_value_policy::reference,
292+
py::arg("name"),
293+
py::arg("values"));
286294
py_benchmark_cls.def(
287295
"set_name",
288296
[](nvbench::benchmark_base &self, std::string name) {
289297
self.set_name(std::move(name));
290298
return std::ref(self);
291299
},
292-
py::return_value_policy::reference);
300+
py::return_value_policy::reference,
301+
py::arg("name"));
293302
py_benchmark_cls.def(
294303
"set_is_cpu_only",
295304
[](nvbench::benchmark_base &self, bool is_cpu_only) {
296305
self.set_is_cpu_only(is_cpu_only);
297306
return std::ref(self);
298307
},
299-
py::return_value_policy::reference);
308+
py::return_value_policy::reference,
309+
py::arg("is_cpu_only"));
300310
py_benchmark_cls.def(
301311
"set_run_once",
302-
[](nvbench::benchmark_base &self, bool v) {
303-
self.set_run_once(v);
312+
[](nvbench::benchmark_base &self, bool run_once) {
313+
self.set_run_once(run_once);
304314
return std::ref(self);
305315
},
306-
py::return_value_policy::reference);
316+
py::return_value_policy::reference,
317+
py::arg("run_once"));
307318

308319
// == STEP 5
309320
// Define PyState class
@@ -421,7 +432,7 @@ PYBIND11_MODULE(_nvbench, m)
421432
&nvbench::state::add_element_count,
422433
py::arg("count"),
423434
py::arg("column_name") = py::str(""));
424-
pystate_cls.def("set_element_count", &nvbench::state::set_element_count);
435+
pystate_cls.def("set_element_count", &nvbench::state::set_element_count, py::arg("count"));
425436
pystate_cls.def("get_element_count", &nvbench::state::get_element_count);
426437

427438
pystate_cls.def("skip", &nvbench::state::skip, py::arg("reason"));
@@ -478,40 +489,49 @@ PYBIND11_MODULE(_nvbench, m)
478489

479490
pystate_cls.def(
480491
"exec",
481-
[](nvbench::state &state, py::object callable_fn, bool batched, bool sync) {
492+
[](nvbench::state &state, py::object py_launcher_fn, bool batched, bool sync) {
493+
if (!PyCallable_Check(py_launcher_fn.ptr()))
494+
{
495+
throw py::type_error("Argument of exec method must be a callable object");
496+
}
497+
482498
// wrapper to invoke Python callable
483-
auto launcher_fn = [callable_fn](nvbench::launch &launch_descr) -> void {
499+
auto cpp_launcher_fn = [py_launcher_fn](nvbench::launch &launch_descr) -> void {
484500
// cast C++ object to python object
485501
auto launch_pyarg = py::cast(std::ref(launch_descr), py::return_value_policy::reference);
486502
// call Python callable
487-
callable_fn(launch_pyarg);
503+
py_launcher_fn(launch_pyarg);
488504
};
489505

490506
if (sync)
491507
{
492508
if (batched)
493509
{
494-
state.exec(nvbench::exec_tag::sync, launcher_fn);
510+
constexpr auto tag = nvbench::exec_tag::sync;
511+
state.exec(tag, cpp_launcher_fn);
495512
}
496513
else
497514
{
498-
state.exec(nvbench::exec_tag::sync | nvbench::exec_tag::no_batch, launcher_fn);
515+
constexpr auto tag = nvbench::exec_tag::sync | nvbench::exec_tag::no_batch;
516+
state.exec(tag, cpp_launcher_fn);
499517
}
500518
}
501519
else
502520
{
503521
if (batched)
504522
{
505-
state.exec(nvbench::exec_tag::none, launcher_fn);
523+
constexpr auto tag = nvbench::exec_tag::none;
524+
state.exec(tag, cpp_launcher_fn);
506525
}
507526
else
508527
{
509-
state.exec(nvbench::exec_tag::no_batch, launcher_fn);
528+
constexpr auto tag = nvbench::exec_tag::no_batch;
529+
state.exec(tag, cpp_launcher_fn);
510530
}
511531
}
512532
},
513-
"Executor for given callable fn(state : Launch)",
514-
py::arg("fn"),
533+
"Executor for given launcher callable fn(state : Launch)",
534+
py::arg("launcher_fn"),
515535
py::pos_only{},
516536
py::arg("batched") = true,
517537
py::arg("sync") = false);
@@ -527,7 +547,7 @@ PYBIND11_MODULE(_nvbench, m)
527547
summ.set_string("name", std::move(column_name));
528548
summ.set_string("value", std::move(value));
529549
},
530-
py::arg("column_name"),
550+
py::arg("name"),
531551
py::arg("value"));
532552
pystate_cls.def(
533553
"add_summary",

0 commit comments

Comments
 (0)