Skip to content

Commit 938d46a

Browse files
Allow kernel_generator to be stateful
In python kernel generator is a user-defined callable. We need to capture Python object of that callable in kernel generator provided for each benchmark. To this end, nvbench::benchmark has been modified to have member of kernel_generator type (must be copy-constructable). Constructor acquires an optional parameter of type `kernel_generator` with default value of default-contstructed instance. nvbench::runner was modified to store kernel_generator instance as well. Its run method creates a fresh copy of stored instance for each invocation, just as it was happening before. nvbench tests/examples pass with this change.
1 parent b1551d2 commit 938d46a

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

nvbench/benchmark.cuh

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,16 @@ struct benchmark final : public benchmark_base
5858

5959
static constexpr std::size_t num_type_configs = nvbench::tl::size<type_configs>{};
6060

61-
benchmark()
61+
benchmark(kernel_generator kgen = {})
6262
: benchmark_base(type_axes{})
63+
, m_kernel_generator(kgen)
6364
{}
6465

6566
private:
66-
std::unique_ptr<benchmark_base> do_clone() const final { return std::make_unique<benchmark>(); }
67+
std::unique_ptr<benchmark_base> do_clone() const final
68+
{
69+
return std::make_unique<benchmark>(this->m_kernel_generator);
70+
}
6771

6872
void do_set_type_axes_names(std::vector<std::string> names) final
6973
{
@@ -72,10 +76,12 @@ private:
7276

7377
void do_run() final
7478
{
75-
nvbench::runner<benchmark> runner{*this};
79+
nvbench::runner<benchmark> runner{*this, this->m_kernel_generator};
7680
runner.generate_states();
7781
runner.run();
7882
}
83+
84+
kernel_generator m_kernel_generator;
7985
};
8086

8187
} // namespace nvbench

nvbench/runner.cuh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ struct runner : public runner_base
5454
using type_configs = typename benchmark_type::type_configs;
5555
static constexpr std::size_t num_type_configs = benchmark_type::num_type_configs;
5656

57-
explicit runner(benchmark_type &bench)
57+
explicit runner(benchmark_type &bench, kernel_generator kgen = {})
5858
: runner_base{bench}
59+
, m_kernel_generator{kgen}
5960
{}
6061

6162
void run()
@@ -98,7 +99,8 @@ private:
9899
self.run_state_prologue(cur_state);
99100
try
100101
{
101-
kernel_generator{}(cur_state, type_config{});
102+
auto kernel_generator_copy = self.m_kernel_generator;
103+
kernel_generator_copy(cur_state, type_config{});
102104
if (cur_state.is_skipped())
103105
{
104106
self.print_skip_notification(cur_state);
@@ -115,6 +117,8 @@ private:
115117
++type_config_index;
116118
});
117119
}
120+
121+
kernel_generator m_kernel_generator;
118122
};
119123

120124
} // namespace nvbench

0 commit comments

Comments
 (0)