Skip to content

Commit 26467f3

Browse files
committed
More cleanup
1 parent 4c964d2 commit 26467f3

File tree

4 files changed

+84
-31
lines changed

4 files changed

+84
-31
lines changed

nvbench/axes_metadata.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ struct axes_metadata
7777
Args &&...args)
7878
{
7979
(this->add_axis(std::forward<Args>(args)), ...);
80-
this->user_iteration_axes({args.get_name()...}, std::move(make));
80+
this->user_iteration_axes(std::move(make), {args.get_name()...});
8181
}
8282

8383
void zip_axes(std::vector<std::string> names);
8484

8585
void
86-
user_iteration_axes(std::vector<std::string> names,
87-
std::function<nvbench::make_user_space_signature> make);
86+
user_iteration_axes(std::function<nvbench::make_user_space_signature> make,
87+
std::vector<std::string> names);
8888

8989
[[nodiscard]] const axes_iteration_space &get_type_iteration_space() const
9090
{

nvbench/axes_metadata.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ void axes_metadata::zip_axes(std::vector<std::string> names)
265265
}
266266

267267
void axes_metadata::user_iteration_axes(
268-
std::vector<std::string> names,
269-
std::function<nvbench::make_user_space_signature> make)
268+
std::function<nvbench::make_user_space_signature> make,
269+
std::vector<std::string> names)
270270
{
271271
// compute the numeric indice for each name we have
272272
auto [input_indices,

nvbench/benchmark_base.cuh

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -118,28 +118,13 @@ struct benchmark_base
118118
return *this;
119119
}
120120

121-
benchmark_base &zip_axes(std::vector<std::string> names)
122-
{
123-
m_axes.zip_axes(std::move(names));
124-
return *this;
125-
}
126-
127121
template<typename... Args>
128122
benchmark_base &add_user_iteration_axes(Args&&... args)
129123
{
130124
m_axes.add_user_iteration_axes(std::forward<Args>(args)...);
131125
return *this;
132126
}
133127

134-
benchmark_base &
135-
user_iteration_axes(std::vector<std::string> names,
136-
std::function<nvbench::make_user_space_signature> make)
137-
{
138-
m_axes.user_iteration_axes(std::move(names), std::move(make));
139-
return *this;
140-
}
141-
142-
143128
benchmark_base &set_devices(std::vector<int> device_ids);
144129

145130
benchmark_base &set_devices(std::vector<nvbench::device_info> devices)
@@ -272,6 +257,38 @@ struct benchmark_base
272257
/// @}
273258

274259
protected:
260+
261+
/// Move existing Axis to being part of zip axis iteration space.
262+
/// This will remove any existing iteration spaces that the named axis
263+
/// are part of, while restoring all other axis in those spaces to
264+
/// the default linear space
265+
///
266+
/// This is meant to be used only by the option_parser
267+
/// @{
268+
benchmark_base &zip_axes(std::vector<std::string> names)
269+
{
270+
m_axes.zip_axes(std::move(names));
271+
return *this;
272+
}
273+
/// @}
274+
275+
276+
/// Move existing Axis to being part of user axis iteration space.
277+
/// This will remove any existing iteration spaces that the named axis
278+
/// are part of, while restoring all other axis in those spaces to
279+
/// the default linear space
280+
///
281+
/// This is meant to be used only by the option_parser
282+
/// @{
283+
benchmark_base &
284+
user_iteration_axes(std::function<nvbench::make_user_space_signature> make,
285+
std::vector<std::string> names)
286+
{
287+
m_axes.user_iteration_axes(std::move(make), std::move(names));
288+
return *this;
289+
}
290+
/// @}
291+
275292
friend struct nvbench::runner_base;
276293

277294
template <typename BenchmarkType>

testing/axes_iteration_space.cu

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,42 @@ void no_op_generator(nvbench::state &state)
6060
}
6161
NVBENCH_DEFINE_CALLABLE(no_op_generator, no_op_callable);
6262

63+
template <typename KernelGenerator, typename TypeAxes = nvbench::type_list<>>
64+
struct rezippable_benchmark final : public nvbench::benchmark_base
65+
{
66+
using kernel_generator = KernelGenerator;
67+
using type_axes = TypeAxes;
68+
using type_configs = nvbench::tl::cartesian_product<type_axes>;
69+
70+
static constexpr std::size_t num_type_configs =
71+
nvbench::tl::size<type_configs>{};
72+
73+
rezippable_benchmark()
74+
: benchmark_base(type_axes{})
75+
{}
76+
77+
using nvbench::benchmark_base::zip_axes;
78+
using nvbench::benchmark_base::user_iteration_axes;
79+
80+
private:
81+
std::unique_ptr<benchmark_base> do_clone() const final
82+
{
83+
return std::make_unique<rezippable_benchmark>();
84+
}
85+
86+
void do_set_type_axes_names(std::vector<std::string> names) final
87+
{
88+
m_axes.set_type_axes_names(std::move(names));
89+
}
90+
91+
void do_run() final
92+
{
93+
nvbench::runner<rezippable_benchmark> runner{*this};
94+
runner.generate_states();
95+
runner.run();
96+
}
97+
};
98+
6399
template <typename Integer, typename Float, typename Other>
64100
void template_no_op_generator(nvbench::state &state,
65101
nvbench::type_list<Integer, Float, Other>)
@@ -91,7 +127,7 @@ void test_zip_axes()
91127

92128
void test_tie_invalid_names()
93129
{
94-
using benchmark_type = nvbench::benchmark<no_op_callable>;
130+
using benchmark_type = rezippable_benchmark<no_op_callable>;
95131
benchmark_type bench;
96132
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
97133
bench.add_int64_axis("I64 Axis", {1, 3, 2});
@@ -114,11 +150,11 @@ void test_tie_unequal_length()
114150

115151
void test_tie_type_axi()
116152
{
117-
using benchmark_type =
118-
nvbench::benchmark<template_no_op_callable,
119-
nvbench::type_list<nvbench::type_list<nvbench::int8_t>,
120-
nvbench::type_list<nvbench::float32_t>,
121-
nvbench::type_list<bool>>>;
153+
using benchmark_type = rezippable_benchmark<
154+
template_no_op_callable,
155+
nvbench::type_list<nvbench::type_list<nvbench::int8_t>,
156+
nvbench::type_list<nvbench::float32_t>,
157+
nvbench::type_list<bool>>>;
122158
benchmark_type bench;
123159
bench.set_type_axes_names({"Integer", "Float", "Other"});
124160
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
@@ -129,7 +165,7 @@ void test_tie_type_axi()
129165

130166
void test_rezip_axes()
131167
{
132-
using benchmark_type = nvbench::benchmark<no_op_callable>;
168+
using benchmark_type = rezippable_benchmark<no_op_callable>;
133169
benchmark_type bench;
134170
bench.add_int64_axis("IAxis_A", {1, 3, 2, 4, 5});
135171
bench.add_int64_axis("IAxis_B", {1, 3, 2, 4, 5});
@@ -155,7 +191,7 @@ void test_rezip_axes()
155191

156192
void test_rezip_axes2()
157193
{
158-
using benchmark_type = nvbench::benchmark<no_op_callable>;
194+
using benchmark_type = rezippable_benchmark<no_op_callable>;
159195
benchmark_type bench;
160196
bench.add_int64_axis("IAxis_A", {1, 3, 2, 4, 5});
161197
bench.add_int64_axis("IAxis_B", {1, 3, 2, 4, 5});
@@ -298,15 +334,15 @@ struct under_diag final : nvbench::user_axis_space
298334

299335
void test_user_axes()
300336
{
301-
using benchmark_type = nvbench::benchmark<no_op_callable>;
337+
using benchmark_type = rezippable_benchmark<no_op_callable>;
302338
benchmark_type bench;
303339
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
304340
bench.add_int64_axis("I64 Axis", {1, 3, 2, 4, 5});
305341
bench.user_iteration_axes(
306-
{"F64 Axis", "I64 Axis"},
307342
[](auto... args) -> std::unique_ptr<nvbench::axis_space_base> {
308343
return std::make_unique<under_diag>(args...);
309-
});
344+
},
345+
{"F64 Axis", "I64 Axis"});
310346

311347
ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(),
312348
"Got {}",

0 commit comments

Comments
 (0)