2424#include < fmt/ranges.h>
2525
2626#include < algorithm>
27+ #include < numeric>
2728#include < stdexcept>
2829#include < unordered_set>
2930
@@ -117,128 +118,45 @@ catch (std::exception &e)
117118void axes_metadata::add_float64_axis (std::string name,
118119 std::vector<nvbench::float64_t > data)
119120{
120- this ->add_axis (nvbench::float64_axis{name,data});
121+ this ->add_axis (nvbench::float64_axis{name, data});
121122}
122123
123124void axes_metadata::add_int64_axis (std::string name,
124125 std::vector<nvbench::int64_t > data,
125126 nvbench::int64_axis_flags flags)
126127{
127- this ->add_axis (nvbench::int64_axis{name,data,flags});
128+ this ->add_axis (nvbench::int64_axis{name, data, flags});
128129}
129130
130131void axes_metadata::add_string_axis (std::string name,
131132 std::vector<std::string> data)
132133{
133- this ->add_axis (nvbench::string_axis{name,data});
134+ this ->add_axis (nvbench::string_axis{name, data});
134135}
135136
136- void axes_metadata::add_axis (const axis_base& axis)
137+ void axes_metadata::add_axis (const axis_base & axis)
137138{
138139 m_value_space.push_back (
139140 std::make_unique<linear_axis_space>(m_axes.size (),
140141 m_axes.size () - m_type_axe_count));
141142 m_axes.push_back (axis.clone ());
142143}
143144
144- namespace
145+ void axes_metadata::add_zip_space (std:: size_t first_index, std:: size_t count)
145146{
146- std::tuple<std::vector<std::size_t >, std::vector<std::size_t >>
147- get_axes_indices (std::size_t type_axe_count,
148- const nvbench::axes_metadata::axes_type &axes,
149- const std::vector<std::string> &names)
150- {
151- std::vector<std::size_t > input_indices;
152- input_indices.reserve (names.size ());
153- for (auto &n : names)
154- {
155- auto iter =
156- std::find_if (axes.cbegin (), axes.cend (), [&n](const auto &axis) {
157- return axis->get_name () == n;
158- });
159-
160- // iter distance is input_indices
161- if (iter == axes.cend ())
162- {
163- NVBENCH_THROW (std::runtime_error,
164- " Unable to find the axes named ({})." ,
165- n);
166- }
167- auto index = std::distance (axes.cbegin (), iter);
168- input_indices.push_back (index);
169- }
170-
171- std::vector<std::size_t > output_indices = input_indices;
172- for (auto &out : output_indices)
173- {
174- out -= type_axe_count;
175- }
176- return {std::move (input_indices), std::move (output_indices)};
177- }
178-
179- void reset_iteration_space (
180- nvbench::axes_metadata::iteration_space_type &all_spaces,
181- const std::vector<std::size_t > &indices_to_remove)
182- {
183- // 1. Find all spaces indices that
184- nvbench::axes_metadata::iteration_space_type reset_space;
185- nvbench::axes_metadata::iteration_space_type to_filter;
186- for (auto &space : all_spaces)
187- {
188- bool added = false ;
189- for (auto &i : indices_to_remove)
190- {
191- if (space->contains (i))
192- {
193- // add each item back as linear_axis_space
194- auto as_linear = space->clone_as_linear ();
195- to_filter.insert (to_filter.end (),
196- std::make_move_iterator (as_linear.begin ()),
197- std::make_move_iterator (as_linear.end ()));
198- added = true ;
199- break ;
200- }
201- }
202- if (!added)
203- {
204- // this space doesn't need to be removed
205- reset_space.push_back (std::move (space));
206- }
207- }
208-
209- for (auto &iter : to_filter)
210- {
211- bool to_add = true ;
212- for (auto &i : indices_to_remove)
213- {
214- if (iter->contains (i))
215- {
216- to_add = false ;
217- break ;
218- }
219- }
220- if (to_add)
221- {
222- reset_space.push_back (std::move (iter));
223- break ;
224- }
225- }
226-
227- all_spaces = std::move (reset_space);
228- }
229- } // namespace
230-
231- void axes_metadata::zip_axes (std::vector<std::string> names)
232- {
233- NVBENCH_THROW_IF ((names.size () < 2 ),
147+ NVBENCH_THROW_IF ((count < 2 ),
234148 std::runtime_error,
235- " At least two axi names ( {} provided ) need to be provided "
149+ " At least two axi ( {} provided ) need to be provided "
236150 " when using zip_axes." ,
237- names. size () );
151+ count );
238152
239153 // compute the numeric indice for each name we have
240- auto [input_indices,
241- output_indices] = get_axes_indices (m_type_axe_count, m_axes, names);
154+ std::vector<std::size_t > input_indices (count);
155+ std::vector<std::size_t > output_indices (count);
156+ std::iota (input_indices.begin (), input_indices.end (), first_index);
157+ std::iota (input_indices.begin (),
158+ input_indices.end (),
159+ first_index - m_type_axe_count);
242160
243161 const auto expected_size = m_axes[input_indices[0 ]]->get_size ();
244162 for (auto i : input_indices)
@@ -255,22 +173,24 @@ void axes_metadata::zip_axes(std::vector<std::string> names)
255173 expected_size);
256174 }
257175
258- // remove any iteration spaces that have axes we need
259- reset_iteration_space (m_value_space, input_indices);
260-
261176 // add the new tied iteration space
262177 auto tied = std::make_unique<zip_axis_space>(std::move (input_indices),
263178 std::move (output_indices));
264179 m_value_space.push_back (std::move (tied));
265180}
266181
267- void axes_metadata::user_iteration_axes (
182+ void axes_metadata::add_user_iteration_space (
268183 std::function<nvbench::make_user_space_signature> make,
269- std::vector<std::string> names)
184+ std::size_t first_index,
185+ std::size_t count)
270186{
271187 // compute the numeric indice for each name we have
272- auto [input_indices,
273- output_indices] = get_axes_indices (m_type_axe_count, m_axes, names);
188+ std::vector<std::size_t > input_indices (count);
189+ std::vector<std::size_t > output_indices (count);
190+ std::iota (input_indices.begin (), input_indices.end (), first_index);
191+ std::iota (input_indices.begin (),
192+ input_indices.end (),
193+ first_index - m_type_axe_count);
274194
275195 for (auto i : input_indices)
276196 {
@@ -281,9 +201,6 @@ void axes_metadata::user_iteration_axes(
281201 m_axes[i]->get_name ());
282202 }
283203
284- // remove any iteration spaces that have axes we need
285- reset_iteration_space (m_value_space, input_indices);
286-
287204 auto user_func = make (std::move (input_indices), std::move (output_indices));
288205 m_value_space.push_back (std::move (user_func));
289206}
0 commit comments