@@ -60,6 +60,42 @@ void no_op_generator(nvbench::state &state)
6060}
6161NVBENCH_DEFINE_CALLABLE (no_op_generator, no_op_callable);
6262
63+ template <typename KernelGenerator, typename TypeAxes = nvbench::type_list<>>
64+ struct rezippable_benchmark final : public nvbench::benchmark_base
65+ {
66+ using kernel_generator = KernelGenerator;
67+ using type_axes = TypeAxes;
68+ using type_configs = nvbench::tl::cartesian_product<type_axes>;
69+
70+ static constexpr std::size_t num_type_configs =
71+ nvbench::tl::size<type_configs>{};
72+
73+ rezippable_benchmark ()
74+ : benchmark_base(type_axes{})
75+ {}
76+
77+ using nvbench::benchmark_base::zip_axes;
78+ using nvbench::benchmark_base::user_iteration_axes;
79+
80+ private:
81+ std::unique_ptr<benchmark_base> do_clone () const final
82+ {
83+ return std::make_unique<rezippable_benchmark>();
84+ }
85+
86+ void do_set_type_axes_names (std::vector<std::string> names) final
87+ {
88+ m_axes.set_type_axes_names (std::move (names));
89+ }
90+
91+ void do_run () final
92+ {
93+ nvbench::runner<rezippable_benchmark> runner{*this };
94+ runner.generate_states ();
95+ runner.run ();
96+ }
97+ };
98+
6399template <typename Integer, typename Float, typename Other>
64100void template_no_op_generator (nvbench::state &state,
65101 nvbench::type_list<Integer, Float, Other>)
@@ -91,7 +127,7 @@ void test_zip_axes()
91127
92128void test_tie_invalid_names ()
93129{
94- using benchmark_type = nvbench::benchmark <no_op_callable>;
130+ using benchmark_type = rezippable_benchmark <no_op_callable>;
95131 benchmark_type bench;
96132 bench.add_float64_axis (" F64 Axis" , {0 ., .1 , .25 , .5 , 1 .});
97133 bench.add_int64_axis (" I64 Axis" , {1 , 3 , 2 });
@@ -114,11 +150,11 @@ void test_tie_unequal_length()
114150
115151void test_tie_type_axi ()
116152{
117- using benchmark_type =
118- nvbench::benchmark< template_no_op_callable,
119- nvbench::type_list<nvbench::type_list<nvbench::int8_t >,
120- nvbench::type_list<nvbench::float32_t >,
121- nvbench::type_list<bool >>>;
153+ using benchmark_type = rezippable_benchmark<
154+ template_no_op_callable,
155+ nvbench::type_list<nvbench::type_list<nvbench::int8_t >,
156+ nvbench::type_list<nvbench::float32_t >,
157+ nvbench::type_list<bool >>>;
122158 benchmark_type bench;
123159 bench.set_type_axes_names ({" Integer" , " Float" , " Other" });
124160 bench.add_float64_axis (" F64 Axis" , {0 ., .1 , .25 , .5 , 1 .});
@@ -129,7 +165,7 @@ void test_tie_type_axi()
129165
130166void test_rezip_axes ()
131167{
132- using benchmark_type = nvbench::benchmark <no_op_callable>;
168+ using benchmark_type = rezippable_benchmark <no_op_callable>;
133169 benchmark_type bench;
134170 bench.add_int64_axis (" IAxis_A" , {1 , 3 , 2 , 4 , 5 });
135171 bench.add_int64_axis (" IAxis_B" , {1 , 3 , 2 , 4 , 5 });
@@ -155,7 +191,7 @@ void test_rezip_axes()
155191
156192void test_rezip_axes2 ()
157193{
158- using benchmark_type = nvbench::benchmark <no_op_callable>;
194+ using benchmark_type = rezippable_benchmark <no_op_callable>;
159195 benchmark_type bench;
160196 bench.add_int64_axis (" IAxis_A" , {1 , 3 , 2 , 4 , 5 });
161197 bench.add_int64_axis (" IAxis_B" , {1 , 3 , 2 , 4 , 5 });
@@ -298,15 +334,15 @@ struct under_diag final : nvbench::user_axis_space
298334
299335void test_user_axes ()
300336{
301- using benchmark_type = nvbench::benchmark <no_op_callable>;
337+ using benchmark_type = rezippable_benchmark <no_op_callable>;
302338 benchmark_type bench;
303339 bench.add_float64_axis (" F64 Axis" , {0 ., .1 , .25 , .5 , 1 .});
304340 bench.add_int64_axis (" I64 Axis" , {1 , 3 , 2 , 4 , 5 });
305341 bench.user_iteration_axes (
306- {" F64 Axis" , " I64 Axis" },
307342 [](auto ... args) -> std::unique_ptr<nvbench::axis_space_base> {
308343 return std::make_unique<under_diag>(args...);
309- });
344+ },
345+ {" F64 Axis" , " I64 Axis" });
310346
311347 ASSERT_MSG (bench.get_config_count () == 15 * bench.get_devices ().size (),
312348 " Got {}" ,
0 commit comments