Skip to content

Commit c8909c7

Browse files
committed
Refactoring / renaming.
1 parent a2bf266 commit c8909c7

13 files changed

+231
-199
lines changed

examples/custom_iteration_spaces.cu

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,13 @@ struct under_diag final : nvbench::user_axis_space
9696
mutable std::size_t y_pos = 0;
9797
mutable std::size_t x_start = 0;
9898

99-
nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
99+
nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
100100
{
101101
// generate our increment function
102102
auto adv_func = [&, info](std::size_t &inc_index, std::size_t /*len*/) -> bool {
103103
inc_index++;
104104
x_pos++;
105-
if (x_pos == info[0].size)
105+
if (x_pos == info[0].axis_size)
106106
{
107107
x_pos = ++x_start;
108108
y_pos = x_start;
@@ -112,25 +112,24 @@ struct under_diag final : nvbench::user_axis_space
112112
};
113113

114114
// our update function
115-
auto diag_under = [&, info](std::size_t,
116-
std::vector<nvbench::detail::axis_index>::iterator start,
117-
std::vector<nvbench::detail::axis_index>::iterator end) {
118-
start->index = x_pos;
119-
end->index = y_pos;
120-
};
115+
auto diag_under =
116+
[&, info](std::size_t, axis_value_indices::iterator start, axis_value_indices::iterator end) {
117+
start->value_index = x_pos;
118+
end->value_index = y_pos;
119+
};
121120

122-
const size_t iteration_length = ((info[0].size * (info[1].size + 1)) / 2);
121+
const size_t iteration_length = ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
123122
return nvbench::detail::axis_space_iterator(info, iteration_length, adv_func, diag_under);
124123
}
125124

126-
std::size_t do_get_size(const axes_info &info) const
125+
std::size_t do_get_size(const axis_value_indices &info) const
127126
{
128-
return ((info[0].size * (info[1].size + 1)) / 2);
127+
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
129128
}
130129

131-
std::size_t do_get_active_count(const axes_info &info) const
130+
std::size_t do_get_active_count(const axis_value_indices &info) const
132131
{
133-
return ((info[0].size * (info[1].size + 1)) / 2);
132+
return ((info[0].axis_size * (info[1].axis_size + 1)) / 2);
134133
}
135134

136135
std::unique_ptr<nvbench::iteration_space_base> do_clone() const
@@ -160,36 +159,38 @@ struct gauss final : nvbench::user_axis_space
160159
: nvbench::user_axis_space(std::move(input_indices))
161160
{}
162161

163-
nvbench::detail::axis_space_iterator do_get_iterator(axes_info info) const
162+
nvbench::detail::axis_space_iterator do_get_iterator(axis_value_indices info) const
164163
{
165-
const double mid_point = static_cast<double>((info[0].size / 2));
164+
const double mid_point = static_cast<double>((info[0].axis_size / 2));
166165

167166
std::random_device rd{};
168167
std::mt19937 gen{rd()};
169168
std::normal_distribution<> d{mid_point, 2};
170169

171-
const size_t iteration_length = info[0].size;
170+
const size_t iteration_length = info[0].axis_size;
172171
std::vector<std::size_t> gauss_indices(iteration_length);
173172
for (auto &g : gauss_indices)
174173
{
175-
auto v = std::min(static_cast<double>(info[0].size), d(gen));
174+
auto v = std::min(static_cast<double>(info[0].axis_size), d(gen));
176175
v = std::max(0.0, v);
177176
g = static_cast<std::size_t>(v);
178177
}
179178

180179
// our update function
181-
auto gauss_func = [=](std::size_t index,
182-
std::vector<nvbench::detail::axis_index>::iterator start,
183-
std::vector<nvbench::detail::axis_index>::iterator) {
184-
start->index = gauss_indices[index];
185-
};
180+
auto gauss_func =
181+
[=](std::size_t index, axis_value_indices::iterator start, axis_value_indices::iterator) {
182+
start->value_index = gauss_indices[index];
183+
};
186184

187185
return nvbench::detail::axis_space_iterator(info, iteration_length, gauss_func);
188186
}
189187

190-
std::size_t do_get_size(const axes_info &info) const { return info[0].size; }
188+
std::size_t do_get_size(const axis_value_indices &info) const { return info[0].axis_size; }
191189

192-
std::size_t do_get_active_count(const axes_info &info) const { return info[0].size; }
190+
std::size_t do_get_active_count(const axis_value_indices &info) const
191+
{
192+
return info[0].axis_size;
193+
}
193194

194195
std::unique_ptr<iteration_space_base> do_clone() const { return std::make_unique<gauss>(*this); }
195196
};

nvbench/detail/axis_space_iterator.cuh

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,76 +30,86 @@ namespace nvbench
3030
namespace detail
3131
{
3232

33-
struct axis_index
33+
// Tracks current value and axis information used while iterating through axes.
34+
struct axis_value_index
3435
{
35-
axis_index() = default;
36-
37-
explicit axis_index(const axis_base *axis)
38-
: index(0)
39-
, name(axis->get_name())
40-
, type(axis->get_type())
41-
, size(axis->get_size())
42-
, active_size(axis->get_size())
43-
{
44-
if (type == nvbench::axis_type::type)
45-
{
46-
active_size = static_cast<const nvbench::type_axis *>(axis)->get_active_count();
47-
}
48-
}
49-
std::size_t index;
50-
std::string name;
51-
nvbench::axis_type type;
52-
std::size_t size;
53-
std::size_t active_size;
36+
axis_value_index() = default;
37+
38+
explicit axis_value_index(const axis_base *axis)
39+
: value_index(0)
40+
, axis_name(axis->get_name())
41+
, axis_type(axis->get_type())
42+
, axis_size(axis->get_size())
43+
, axis_active_size(axis_type == nvbench::axis_type::type
44+
? static_cast<const nvbench::type_axis *>(axis)->get_active_count()
45+
: axis->get_size())
46+
{}
47+
48+
std::size_t value_index;
49+
std::string axis_name;
50+
nvbench::axis_type axis_type;
51+
std::size_t axis_size;
52+
std::size_t axis_active_size;
5453
};
5554

5655
struct axis_space_iterator
5756
{
58-
using axes_info = std::vector<detail::axis_index>;
59-
using AdvanceSignature = bool(std::size_t &current_index, std::size_t length);
60-
using UpdateSignature = void(std::size_t index,
61-
axes_info::iterator start,
62-
axes_info::iterator end);
57+
using axis_value_indices = std::vector<detail::axis_value_index>;
58+
using advance_signature = bool(std::size_t &current_iteration, std::size_t iteration_size);
59+
using update_signature = void(std::size_t current_iteration,
60+
axis_value_indices::iterator start_axis_value_info,
61+
axis_value_indices::iterator end_axis_value_info);
6362

64-
axis_space_iterator(std::vector<detail::axis_index> info,
65-
std::size_t iter_count,
66-
std::function<axis_space_iterator::AdvanceSignature> &&advance,
67-
std::function<axis_space_iterator::UpdateSignature> &&update)
68-
: m_info(info)
69-
, m_iteration_size(iter_count)
63+
axis_space_iterator(axis_value_indices info,
64+
std::size_t iteration_size,
65+
std::function<axis_space_iterator::advance_signature> &&advance,
66+
std::function<axis_space_iterator::update_signature> &&update)
67+
: m_iteration_size(iteration_size)
68+
, m_axis_value_indices(std::move(info))
7069
, m_advance(std::move(advance))
7170
, m_update(std::move(update))
7271
{}
7372

74-
axis_space_iterator(std::vector<detail::axis_index> info,
73+
axis_space_iterator(axis_value_indices info,
7574
std::size_t iter_count,
76-
std::function<axis_space_iterator::UpdateSignature> &&update)
77-
: m_info(info)
78-
, m_iteration_size(iter_count)
75+
std::function<axis_space_iterator::update_signature> &&update)
76+
: m_iteration_size(iter_count)
77+
, m_axis_value_indices(std::move(info))
7978
, m_update(std::move(update))
8079
{}
8180

82-
[[nodiscard]] bool next() { return this->m_advance(m_current_index, m_iteration_size); }
81+
[[nodiscard]] bool next() { return m_advance(m_current_iteration, m_iteration_size); }
8382

84-
void update_indices(std::vector<axis_index> &indices) const
83+
void update_axis_value_indices(axis_value_indices &info) const
8584
{
86-
using diff_t = typename axes_info::difference_type;
87-
indices.insert(indices.end(), m_info.begin(), m_info.end());
88-
axes_info::iterator end = indices.end();
89-
axes_info::iterator start = end - static_cast<diff_t>(m_info.size());
90-
this->m_update(m_current_index, start, end);
85+
using diff_t = typename axis_value_indices::difference_type;
86+
info.insert(info.end(), m_axis_value_indices.begin(), m_axis_value_indices.end());
87+
axis_value_indices::iterator end = info.end();
88+
axis_value_indices::iterator start = end - static_cast<diff_t>(m_axis_value_indices.size());
89+
m_update(m_current_iteration, start, end);
9190
}
9291

93-
axes_info m_info;
94-
std::size_t m_iteration_size = 1;
95-
std::function<AdvanceSignature> m_advance = [](std::size_t &current_index, std::size_t length) {
96-
(current_index + 1 == length) ? current_index = 0 : current_index++;
97-
return (current_index == 0); // we rolled over
98-
};
99-
std::function<UpdateSignature> m_update = nullptr;
92+
[[nodiscard]] const axis_value_indices &get_axis_value_indices() const
93+
{
94+
return m_axis_value_indices;
95+
}
96+
[[nodiscard]] axis_value_indices &get_axis_value_indices() { return m_axis_value_indices; }
97+
98+
[[nodiscard]] std::size_t get_iteration_size() const { return m_iteration_size; }
10099

101100
private:
102-
std::size_t m_current_index = 0;
101+
std::size_t m_current_iteration = 0;
102+
std::size_t m_iteration_size = 1;
103+
104+
axis_value_indices m_axis_value_indices;
105+
106+
std::function<advance_signature> m_advance = [](std::size_t &current_iteration,
107+
std::size_t iteration_size) {
108+
(current_iteration + 1 == iteration_size) ? current_iteration = 0 : current_iteration++;
109+
return (current_iteration == 0); // we rolled over
110+
};
111+
112+
std::function<update_signature> m_update = nullptr;
103113
};
104114

105115
} // namespace detail

nvbench/detail/state_generator.cuh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,12 @@ struct state_iterator
7878

7979
[[nodiscard]] std::size_t get_number_of_states() const;
8080
void init();
81-
[[nodiscard]] std::vector<axis_index> get_current_indices() const;
81+
[[nodiscard]] std::vector<axis_value_index> get_current_axis_value_indices() const;
8282
[[nodiscard]] bool iter_valid() const;
8383
void next();
8484

85-
std::vector<axis_space_iterator> m_space;
85+
std::vector<axis_space_iterator> m_axis_space_iterators;
8686
std::size_t m_axes_count = 0;
87-
std::size_t m_current_space = 0;
8887
std::size_t m_current_iteration = 0;
8988
std::size_t m_max_iteration = 1;
9089
};

nvbench/detail/state_generator.cxx

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818

1919
#include <nvbench/benchmark_base.cuh>
2020
#include <nvbench/detail/state_generator.cuh>
21+
#include <nvbench/detail/throw.cuh>
2122
#include <nvbench/detail/transform_reduce.cuh>
2223
#include <nvbench/device_info.cuh>
2324
#include <nvbench/named_values.cuh>
2425
#include <nvbench/type_axis.cuh>
2526

2627
#include <algorithm>
2728
#include <cassert>
29+
#include <exception>
2830
#include <functional>
2931
#include <numeric>
3032

@@ -34,33 +36,37 @@ namespace nvbench::detail
3436

3537
void state_iterator::add_iteration_space(const nvbench::detail::axis_space_iterator &iter)
3638
{
37-
m_axes_count += iter.m_info.size();
38-
m_max_iteration *= iter.m_iteration_size;
39+
m_axes_count += iter.get_axis_value_indices().size();
40+
m_max_iteration *= iter.get_iteration_size();
3941

40-
m_space.push_back(std::move(iter));
42+
m_axis_space_iterators.push_back(std::move(iter));
4143
}
4244

4345
[[nodiscard]] std::size_t state_iterator::get_number_of_states() const
4446
{
4547
return this->m_max_iteration;
4648
}
4749

48-
void state_iterator::init()
49-
{
50-
m_current_space = 0;
51-
m_current_iteration = 0;
52-
}
50+
void state_iterator::init() { m_current_iteration = 0; }
5351

54-
[[nodiscard]] std::vector<axis_index> state_iterator::get_current_indices() const
52+
[[nodiscard]] std::vector<axis_value_index> state_iterator::get_current_axis_value_indices() const
5553
{
56-
std::vector<axis_index> indices;
57-
indices.reserve(m_axes_count);
58-
for (auto &m : m_space)
54+
std::vector<axis_value_index> info;
55+
info.reserve(m_axes_count);
56+
for (auto &iter : m_axis_space_iterators)
57+
{
58+
iter.update_axis_value_indices(info);
59+
}
60+
61+
if (info.size() != m_axes_count)
5962
{
60-
m.update_indices(indices);
63+
NVBENCH_THROW(std::runtime_error,
64+
"Internal error: State iterator has {} axes, but only {} were updated.",
65+
m_axes_count,
66+
info.size());
6167
}
62-
// verify length
63-
return indices;
68+
69+
return info;
6470
}
6571

6672
[[nodiscard]] bool state_iterator::iter_valid() const
@@ -72,9 +78,9 @@ void state_iterator::next()
7278
{
7379
m_current_iteration++;
7480

75-
for (auto &&space : this->m_space)
81+
for (auto &iter : this->m_axis_space_iterators)
7682
{
77-
auto rolled_over = space.next();
83+
const auto rolled_over = iter.next();
7884
if (rolled_over)
7985
{
8086
continue;
@@ -128,13 +134,13 @@ void state_generator::build_axis_configs()
128134
auto &[config, active_mask] =
129135
m_type_axis_configs.emplace_back(std::make_pair(nvbench::named_values{}, true));
130136

131-
for (const auto &axis_info : ti.get_current_indices())
137+
for (const auto &info : ti.get_current_axis_value_indices())
132138
{
133-
const auto &axis = axes.get_type_axis(axis_info.name);
139+
const auto &axis = axes.get_type_axis(info.axis_name);
134140

135-
active_mask &= axis.get_is_active(axis_info.index);
141+
active_mask &= axis.get_is_active(info.value_index);
136142

137-
config.set_string(axis.get_name(), axis.get_input_string(axis_info.index));
143+
config.set_string(axis.get_name(), axis.get_input_string(info.value_index));
138144
}
139145
}
140146

@@ -143,30 +149,33 @@ void state_generator::build_axis_configs()
143149
auto &config = m_non_type_axis_configs.emplace_back();
144150

145151
// Add non-type parameters to state:
146-
for (const auto &axis_info : vi.get_current_indices())
152+
for (const auto &axis_value : vi.get_current_axis_value_indices())
147153
{
148-
switch (axis_info.type)
154+
switch (axis_value.axis_type)
149155
{
150156
default:
151157
case axis_type::type:
152158
assert("unreachable." && false);
153159
break;
154160
case axis_type::int64:
155-
config.set_int64(axis_info.name,
156-
axes.get_int64_axis(axis_info.name).get_value(axis_info.index));
161+
config.set_int64(
162+
axis_value.axis_name,
163+
axes.get_int64_axis(axis_value.axis_name).get_value(axis_value.value_index));
157164
break;
158165

159166
case axis_type::float64:
160-
config.set_float64(axis_info.name,
161-
axes.get_float64_axis(axis_info.name).get_value(axis_info.index));
167+
config.set_float64(
168+
axis_value.axis_name,
169+
axes.get_float64_axis(axis_value.axis_name).get_value(axis_value.value_index));
162170
break;
163171

164172
case axis_type::string:
165-
config.set_string(axis_info.name,
166-
axes.get_string_axis(axis_info.name).get_value(axis_info.index));
173+
config.set_string(
174+
axis_value.axis_name,
175+
axes.get_string_axis(axis_value.axis_name).get_value(axis_value.value_index));
167176
break;
168177
} // switch (type)
169-
} // for (axis_info : current_indices)
178+
} // for (axis_values)
170179
}
171180

172181
if (m_type_axis_configs.empty())

0 commit comments

Comments
 (0)