|
17 | 17 | */ |
18 | 18 |
|
19 | 19 | #include <nvbench/benchmark.cuh> |
20 | | - |
21 | 20 | #include <nvbench/callable.cuh> |
| 21 | +#include <nvbench/device_manager.cuh> |
22 | 22 | #include <nvbench/named_values.cuh> |
23 | 23 | #include <nvbench/state.cuh> |
24 | 24 | #include <nvbench/type_list.cuh> |
25 | 25 | #include <nvbench/type_strings.cuh> |
26 | 26 | #include <nvbench/types.cuh> |
27 | 27 |
|
28 | | -#include "test_asserts.cuh" |
29 | | - |
30 | 28 | #include <fmt/format.h> |
31 | 29 |
|
32 | 30 | #include <algorithm> |
| 31 | +#include <iterator> |
33 | 32 | #include <utility> |
34 | 33 | #include <variant> |
35 | 34 | #include <vector> |
36 | 35 |
|
| 36 | +#include "test_asserts.cuh" |
| 37 | + |
37 | 38 | template <typename T> |
38 | 39 | std::vector<T> sort(std::vector<T> &&vec) |
39 | 40 | { |
@@ -114,12 +115,18 @@ void test_zip_axes() |
114 | 115 | { |
115 | 116 | using benchmark_type = nvbench::benchmark<no_op_callable>; |
116 | 117 | benchmark_type bench; |
| 118 | + bench.set_devices(nvbench::device_manager::get().get_devices()); |
117 | 119 | bench.add_zip_axes(nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}), |
118 | 120 | nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5})); |
119 | 121 |
|
120 | | - ASSERT_MSG(bench.get_config_count() == 5 * bench.get_devices().size(), |
121 | | - "Got {}", |
122 | | - bench.get_config_count()); |
| 122 | + const auto num_devices = std::max(std::size_t(1), bench.get_devices().size()); |
| 123 | + ASSERT_MSG(bench.get_config_count() == 5 * num_devices, |
| 124 | + "Got {}, expected {}", |
| 125 | + bench.get_config_count(), |
| 126 | + 5 * bench.get_devices().size()); |
| 127 | + |
| 128 | + bench.set_devices(std::vector<int>{}); |
| 129 | + ASSERT_MSG(bench.get_config_count() == 5, "Got {}, expected {}", bench.get_config_count(), 5); |
123 | 130 | } |
124 | 131 |
|
125 | 132 | void test_zip_unequal_length() |
@@ -241,16 +248,19 @@ void test_user_axes() |
241 | 248 | { |
242 | 249 | using benchmark_type = rezippable_benchmark<no_op_callable>; |
243 | 250 | benchmark_type bench; |
| 251 | + bench.set_devices(nvbench::device_manager::get().get_devices()); |
244 | 252 | bench.add_user_iteration_axes( |
245 | 253 | [](auto... args) -> std::unique_ptr<nvbench::iteration_space_base> { |
246 | 254 | return std::make_unique<under_diag>(args...); |
247 | 255 | }, |
248 | 256 | nvbench::float64_axis("F64 Axis", {0., .1, .25, .5, 1.}), |
249 | 257 | nvbench::int64_axis("I64 Axis", {1, 3, 2, 4, 5})); |
250 | 258 |
|
251 | | - ASSERT_MSG(bench.get_config_count() == 15 * bench.get_devices().size(), |
252 | | - "Got {}", |
253 | | - bench.get_config_count()); |
| 259 | + const auto num_devices = std::max(std::size_t(1), bench.get_devices().size()); |
| 260 | + ASSERT_MSG(bench.get_config_count() == 15 * num_devices, "Got {}", bench.get_config_count()); |
| 261 | + |
| 262 | + bench.set_devices(std::vector<int>{}); |
| 263 | + ASSERT_MSG(bench.get_config_count() == 15, "Got {}", bench.get_config_count()); |
254 | 264 | } |
255 | 265 |
|
256 | 266 | int main() |
|
0 commit comments