Skip to content

Commit 250d755

Browse files
committed
Update new test to support device-init changes.
1 parent edefcd0 commit 250d755

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

testing/axes_iteration_space.cu

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,24 @@
1717
*/
1818

1919
#include <nvbench/benchmark.cuh>
20-
2120
#include <nvbench/callable.cuh>
21+
#include <nvbench/device_manager.cuh>
2222
#include <nvbench/named_values.cuh>
2323
#include <nvbench/state.cuh>
2424
#include <nvbench/type_list.cuh>
2525
#include <nvbench/type_strings.cuh>
2626
#include <nvbench/types.cuh>
2727

28-
#include "test_asserts.cuh"
29-
3028
#include <fmt/format.h>
3129

3230
#include <algorithm>
31+
#include <iterator>
3332
#include <utility>
3433
#include <variant>
3534
#include <vector>
3635

36+
#include "test_asserts.cuh"
37+
3738
template <typename T>
3839
std::vector<T> sort(std::vector<T> &&vec)
3940
{
@@ -114,12 +115,18 @@ void test_zip_axes()
114115
{
115116
using benchmark_type = nvbench::benchmark<no_op_callable>;
116117
benchmark_type bench;
118+
bench.set_devices(nvbench::device_manager::get().get_devices());
117119
bench.add_zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
118120
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
119121

120-
ASSERT_MSG(bench.get_config_count() == 5 * bench.get_devices().size(),
121-
"Got {}",
122-
bench.get_config_count());
122+
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
123+
ASSERT_MSG(bench.get_config_count() == 5 * num_devices,
124+
"Got {}, expected {}",
125+
bench.get_config_count(),
126+
5 * bench.get_devices().size());
127+
128+
bench.set_devices(std::vector<int>{});
129+
ASSERT_MSG(bench.get_config_count() == 5, "Got {}, expected {}", bench.get_config_count(), 5);
123130
}
124131

125132
void test_zip_unequal_length()
@@ -241,16 +248,19 @@ void test_user_axes()
241248
{
242249
using benchmark_type = rezippable_benchmark<no_op_callable>;
243250
benchmark_type bench;
251+
bench.set_devices(nvbench::device_manager::get().get_devices());
244252
bench.add_user_iteration_axes(
245253
[](auto... args) -> std::unique_ptr<nvbench::iteration_space_base> {
246254
return std::make_unique<under_diag>(args...);
247255
},
248256
nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}),
249257
nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5}));
250258

251-
ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(),
252-
"Got {}",
253-
bench.get_config_count());
259+
const auto num_devices = std::max(std::size_t(1), bench.get_devices().size());
260+
ASSERT_MSG(bench.get_config_count() == 15 * num_devices, "Got {}", bench.get_config_count());
261+
262+
bench.set_devices(std::vector<int>{});
263+
ASSERT_MSG(bench.get_config_count() == 15, "Got {}", bench.get_config_count());
254264
}
255265

256266
int main()

0 commit comments

Comments
 (0)