Skip to content

Commit a86ecb6

Browse files
Add State.get_axis_values and State.get_axis_values_as_string
Add nvbench.State methods to get Python dictionary representing axis values of benchmark configuration state represents. get_axis_values_as_string gives a string of space-separated name=values pairs.
1 parent 428ddee commit a86ecb6

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

python/cuda/nvbench/__init__.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ class State:
238238
def add_summary(self, column_name: str, value: Union[int, float, str]) -> None:
239239
"Add summary column with a value"
240240
...
241+
def get_axis_values(self) -> dict[str, int | float | str]:
242+
"Get dictionary with axis values for this configuration"
243+
...
244+
def get_axis_values_as_string(self) -> str:
245+
"Get string of space-separated name=value pairs for this configuration"
246+
...
241247

242248
def register(fn: Callable[[State], None]) -> Benchmark:
243249
"""

python/src/py_nvbench.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,25 @@ class GlobalBenchmarkRegistry
195195
}
196196
};
197197

198+
py::dict py_get_axis_values(const nvbench::state &state)
199+
{
200+
auto named_values = state.get_axis_values();
201+
202+
auto names = named_values.get_names();
203+
py::dict res;
204+
205+
for (const auto &name : names)
206+
{
207+
if (named_values.has_value(name))
208+
{
209+
auto v = named_values.get_value(name);
210+
res[name.c_str()] = py::cast(v);
211+
}
212+
}
213+
214+
return res;
215+
}
216+
198217
// essentially a global variable, but allocated on the heap during module initialization
199218
constinit std::unique_ptr<GlobalBenchmarkRegistry, py::nodelete> global_registry{};
200219

@@ -569,6 +588,9 @@ PYBIND11_MODULE(_nvbench, m)
569588
},
570589
py::arg("name"),
571590
py::arg("value"));
591+
pystate_cls.def("get_axis_values_as_string",
592+
[](const nvbench::state &state) { return state.get_axis_values_as_string(); });
593+
pystate_cls.def("get_axis_values", &py_get_axis_values);
572594

573595
// Use handle to take a memory leak here, since this object's destructor may be called after
574596
// interpreter has shut down

0 commit comments

Comments
 (0)