Skip to content

Commit 0a2130f

Browse files
committed
implement easier API to add axis and zip/user iteration at the same time
1 parent a25f578 commit 0a2130f

File tree

9 files changed

+118
-64
lines changed

9 files changed

+118
-64
lines changed

examples/custom_iteration_spaces.cu

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ void tied_copy_sweep_grid_shape(nvbench::state &state)
7272
}
7373
NVBENCH_BENCH(tied_copy_sweep_grid_shape)
7474
// Every power of two from 64->1024:
75-
.add_int64_axis("BlockSize", {32,64,128,256})
76-
.add_int64_axis("NumBlocks", {1024,512,256,128})
77-
.zip_axes({"BlockSize", "NumBlocks"});
75+
.zip_axes( nvbench::int64_axis{"BlockSize", {32,64,128,256}},
76+
nvbench::int64_axis{"NumBlocks", {1024,512,256,128}});
7877

7978
//==============================================================================
8079
// under_diag:
@@ -154,15 +153,12 @@ void user_copy_sweep_grid_shape(nvbench::state &state)
154153
copy_sweep_grid_shape(state);
155154
}
156155
NVBENCH_BENCH(user_copy_sweep_grid_shape)
157-
// Every power of two from 64->1024:
158-
.add_int64_power_of_two_axis("BlockSize", nvbench::range(6, 10))
159-
.add_int64_power_of_two_axis("NumBlocks", nvbench::range(6, 10))
160-
.user_iteration_axes({"NumBlocks", "BlockSize"},
161-
[](auto... args)
162-
-> std::unique_ptr<nvbench::axis_space_base> {
163-
return std::make_unique<under_diag>(args...);
164-
});
165-
156+
.user_iteration_axes(
157+
[](auto... args) -> std::unique_ptr<nvbench::axis_space_base> {
158+
return std::make_unique<under_diag>(args...);
159+
},
160+
nvbench::int64_axis("BlockSize", {64, 128, 256, 512, 1024}),
161+
nvbench::int64_axis("NumBlocks", {1024, 521, 256, 128, 64}));
166162

167163
//==============================================================================
168164
// gauss:
@@ -233,15 +229,13 @@ void dual_float64_axis(nvbench::state &state)
233229
});
234230
}
235231
NVBENCH_BENCH(dual_float64_axis)
236-
.add_float64_axis("Duration_A", nvbench::range(0., 1e-4, 1e-5))
237-
.add_float64_axis("Duration_B", nvbench::range(0., 1e-4, 1e-5))
238-
.user_iteration_axes({"Duration_A"},
239-
[](auto... args)
240-
-> std::unique_ptr<nvbench::axis_space_base> {
241-
return std::make_unique<gauss>(args...);
242-
})
243-
.user_iteration_axes({"Duration_B"},
244-
[](auto... args)
245-
-> std::unique_ptr<nvbench::axis_space_base> {
246-
return std::make_unique<gauss>(args...);
247-
});
232+
.user_iteration_axes(
233+
[](auto... args) -> std::unique_ptr<nvbench::axis_space_base> {
234+
return std::make_unique<gauss>(args...);
235+
},
236+
nvbench::float64_axis("Duration_A", nvbench::range(0., 1e-4, 1e-5)))
237+
.user_iteration_axes(
238+
[](auto... args) -> std::unique_ptr<nvbench::axis_space_base> {
239+
return std::make_unique<gauss>(args...);
240+
},
241+
nvbench::float64_axis("Duration_B", nvbench::range(0., 1e-4, 1e-5)));

nvbench/axes_metadata.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,26 @@ struct axes_metadata
6262

6363
void add_string_axis(std::string name, std::vector<std::string> data);
6464

65+
void add_axis(const axis_base& axis);
66+
67+
template<typename... Args>
68+
void zip_axes(Args&&... args)
69+
{
70+
(this->add_axis(std::forward<Args>(args)),...);
71+
this->zip_axes({args.get_name()...});
72+
}
73+
6574
void zip_axes(std::vector<std::string> names);
6675

76+
template<typename... Args>
77+
void
78+
user_iteration_axes(std::function<nvbench::make_user_space_signature> make,
79+
Args&&... args)
80+
{
81+
(this->add_axis(std::forward<Args>(args)),...);
82+
this->user_iteration_axes({args.get_name()...}, std::move(make));
83+
}
84+
6785
void
6886
user_iteration_axes(std::vector<std::string> names,
6987
std::function<nvbench::make_user_space_signature> make);

nvbench/axes_metadata.cxx

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -117,38 +117,28 @@ catch (std::exception &e)
117117
void axes_metadata::add_float64_axis(std::string name,
118118
std::vector<nvbench::float64_t> data)
119119
{
120-
m_value_space.push_back(
121-
std::make_unique<linear_axis_space>(m_axes.size(),
122-
m_axes.size() - m_type_axe_count));
123-
124-
auto axis = std::make_unique<nvbench::float64_axis>(std::move(name));
125-
axis->set_inputs(std::move(data));
126-
m_axes.push_back(std::move(axis));
120+
this->add_axis(nvbench::float64_axis{name,data});
127121
}
128122

129123
void axes_metadata::add_int64_axis(std::string name,
130124
std::vector<nvbench::int64_t> data,
131125
nvbench::int64_axis_flags flags)
132126
{
133-
m_value_space.push_back(
134-
std::make_unique<linear_axis_space>(m_axes.size(),
135-
m_axes.size() - m_type_axe_count));
136-
137-
auto axis = std::make_unique<nvbench::int64_axis>(std::move(name));
138-
axis->set_inputs(std::move(data), flags);
139-
m_axes.push_back(std::move(axis));
127+
this->add_axis(nvbench::int64_axis{name,data,flags});
140128
}
141129

142130
void axes_metadata::add_string_axis(std::string name,
143131
std::vector<std::string> data)
132+
{
133+
this->add_axis(nvbench::string_axis{name,data});
134+
}
135+
136+
void axes_metadata::add_axis(const axis_base& axis)
144137
{
145138
m_value_space.push_back(
146139
std::make_unique<linear_axis_space>(m_axes.size(),
147140
m_axes.size() - m_type_axe_count));
148-
149-
auto axis = std::make_unique<nvbench::string_axis>(std::move(name));
150-
axis->set_inputs(std::move(data));
151-
m_axes.push_back(std::move(axis));
141+
m_axes.push_back(axis.clone());
152142
}
153143

154144
namespace

nvbench/benchmark_base.cuh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,26 @@ struct benchmark_base
111111
return *this;
112112
}
113113

114+
template<typename... Args>
115+
benchmark_base &zip_axes(Args&&... args)
116+
{
117+
m_axes.zip_axes(std::forward<Args>(args)...);
118+
return *this;
119+
}
120+
114121
benchmark_base &zip_axes(std::vector<std::string> names)
115122
{
116123
m_axes.zip_axes(std::move(names));
117124
return *this;
118125
}
119126

127+
template<typename... Args>
128+
benchmark_base &user_iteration_axes(Args&&... args)
129+
{
130+
m_axes.user_iteration_axes(std::forward<Args>(args)...);
131+
return *this;
132+
}
133+
120134
benchmark_base &
121135
user_iteration_axes(std::vector<std::string> names,
122136
std::function<nvbench::make_user_space_signature> make)

nvbench/float64_axis.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ struct float64_axis final : public axis_base
3434
, m_values{}
3535
{}
3636

37+
explicit float64_axis(std::string name, std::vector<nvbench::float64_t> inputs)
38+
: axis_base{std::move(name), axis_type::float64}
39+
, m_values{std::move(inputs)}
40+
{}
41+
3742
~float64_axis() final;
3843

3944
void set_inputs(std::vector<nvbench::float64_t> inputs)

nvbench/int64_axis.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ struct int64_axis final : public axis_base
5151
, m_flags{int64_axis_flags::none}
5252
{}
5353

54+
explicit int64_axis(std::string name,
55+
std::vector<int64_t> inputs,
56+
int64_axis_flags flags = int64_axis_flags::none);
57+
5458
~int64_axis() final;
5559

5660
[[nodiscard]] bool is_power_of_two() const

nvbench/int64_axis.cxx

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,24 @@
2626
#include <stdexcept>
2727
#include <vector>
2828

29-
namespace nvbench
29+
namespace
3030
{
3131

32-
int64_axis::~int64_axis() = default;
33-
34-
void int64_axis::set_inputs(std::vector<int64_t> inputs, int64_axis_flags flags)
32+
std::vector<nvbench::int64_t>
33+
construct_values(nvbench::int64_axis_flags flags,
34+
const std::vector<nvbench::int64_t> &inputs)
3535
{
36-
m_inputs = std::move(inputs);
37-
m_flags = flags;
3836

39-
if (!this->is_power_of_two())
37+
std::vector<int64_t> values;
38+
const bool is_power_of_two =
39+
static_cast<bool>(flags & nvbench::int64_axis_flags::power_of_two);
40+
if (!is_power_of_two)
4041
{
41-
m_values = m_inputs;
42+
values = inputs;
4243
}
4344
else
4445
{
45-
m_values.resize(m_inputs.size());
46+
values.resize(inputs.size());
4647

4748
auto conv = [](int64_t in) -> int64_t {
4849
if (in < 0 || in >= 64)
@@ -52,11 +53,35 @@ void int64_axis::set_inputs(std::vector<int64_t> inputs, int64_axis_flags flags)
5253
"Input={} ValidRange=[0, 63]",
5354
in);
5455
}
55-
return int64_axis::compute_pow2(in);
56+
return nvbench::int64_axis::compute_pow2(in);
5657
};
5758

58-
std::transform(m_inputs.cbegin(), m_inputs.cend(), m_values.begin(), conv);
59+
std::transform(inputs.cbegin(), inputs.cend(), values.begin(), conv);
5960
}
61+
62+
return values;
63+
}
64+
} // namespace
65+
66+
namespace nvbench
67+
{
68+
69+
int64_axis::int64_axis(std::string name,
70+
std::vector<int64_t> inputs,
71+
int64_axis_flags flags)
72+
: axis_base{std::move(name), axis_type::int64}
73+
, m_inputs{std::move(inputs)}
74+
, m_values{construct_values(flags, m_inputs)}
75+
, m_flags{flags}
76+
{}
77+
78+
int64_axis::~int64_axis() = default;
79+
80+
void int64_axis::set_inputs(std::vector<int64_t> inputs, int64_axis_flags flags)
81+
{
82+
m_inputs = std::move(inputs);
83+
m_flags = flags;
84+
m_values = construct_values(flags, m_inputs);
6085
}
6186

6287
std::string int64_axis::do_get_input_string(std::size_t i) const

nvbench/string_axis.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ struct string_axis final : public axis_base
3434
, m_values{}
3535
{}
3636

37+
explicit string_axis(std::string name, std::vector<std::string> inputs)
38+
: axis_base{std::move(name), axis_type::string}
39+
, m_values{std::move(inputs)}
40+
{}
41+
3742
~string_axis() final;
3843

3944
void set_inputs(std::vector<std::string> inputs)

testing/axes_iteration_space.cu

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,8 @@ void test_zip_axes()
8181
{
8282
using benchmark_type = nvbench::benchmark<no_op_callable>;
8383
benchmark_type bench;
84-
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
85-
bench.add_int64_axis("I64 Axis", {1, 3, 2, 4, 5});
86-
bench.zip_axes({"F64 Axis", "I64 Axis"});
84+
bench.zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
85+
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
8786

8887
ASSERT_MSG(bench.get_config_count() == 5 * bench.get_devices().size(),
8988
"Got {}",
@@ -107,11 +106,10 @@ void test_tie_unequal_length()
107106
{
108107
using benchmark_type = nvbench::benchmark<no_op_callable>;
109108
benchmark_type bench;
110-
bench.add_float64_axis("F64 Axis", {0., .1, .25, .5, 1.});
111-
bench.add_int64_axis("I64 Axis", {1, 3, 2});
112109

113-
bench.zip_axes({"I64 Axis", "F64 Axis"});
114-
ASSERT_THROWS_ANY(bench.zip_axes({"F64 Axis", "I64 Axis"}));
110+
ASSERT_THROWS_ANY(
111+
bench.zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
112+
nvbench::int64_axis("I64 Axis", {1, 3, 2})));
115113
}
116114

117115
void test_tie_type_axi()
@@ -191,11 +189,11 @@ void test_tie_clone()
191189
using benchmark_type = nvbench::benchmark<no_op_callable>;
192190
benchmark_type bench;
193191
bench.set_devices(std::vector<int>{});
194-
bench.add_string_axis("Strings", {"string a", "string b", "string c"});
195192
bench.add_int64_power_of_two_axis("I64 POT Axis", {10, 20});
196193
bench.add_int64_axis("I64 Axis", {10, 20});
197-
bench.add_float64_axis("F64 Axis", {0., .1, .25});
198-
bench.zip_axes({"F64 Axis", "Strings"});
194+
bench.zip_axes(nvbench::string_axis("Strings",
195+
{"string a", "string b", "string c"}),
196+
nvbench::float64_axis("F64 Axis", {0., .1, .25}));
199197

200198
const auto expected_count = bench.get_config_count();
201199

@@ -237,7 +235,8 @@ struct under_diag final : nvbench::user_axis_space
237235
{
238236
under_diag(std::vector<std::size_t> input_indices,
239237
std::vector<std::size_t> output_indices)
240-
: nvbench::user_axis_space(std::move(input_indices), std::move(output_indices))
238+
: nvbench::user_axis_space(std::move(input_indices),
239+
std::move(output_indices))
241240
{}
242241

243242
mutable std::size_t x_pos = 0;

0 commit comments

Comments
 (0)