Skip to content

Commit 9d18928

Browse files
authored
Fix get_config_count for CPU-only benchmarks. (#218)
1 parent 433376f commit 9d18928

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

nvbench/benchmark_base.cxx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include <nvbench/criterion_manager.cuh>
2121
#include <nvbench/detail/transform_reduce.cuh>
2222

23+
#include <algorithm>
24+
#include <cstdint>
25+
2326
namespace nvbench
2427
{
2528

@@ -86,7 +89,8 @@ std::size_t benchmark_base::get_config_count() const
8689
return axis_ptr->get_size();
8790
});
8891

89-
return per_device_count * m_devices.size();
92+
// Devices will be empty for cpu-only benchmarks.
93+
return per_device_count * std::max(std::size_t(1), m_devices.size());
9094
}
9195

9296
benchmark_base &benchmark_base::set_stopping_criterion(std::string criterion)

testing/benchmark.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <nvbench/benchmark.cuh>
2020
#include <nvbench/callable.cuh>
21+
#include <nvbench/device_manager.cuh>
2122
#include <nvbench/named_values.cuh>
2223
#include <nvbench/state.cuh>
2324
#include <nvbench/type_list.cuh>
@@ -27,6 +28,7 @@
2728
#include <fmt/format.h>
2829

2930
#include <algorithm>
31+
#include <cstdint>
3032
#include <utility>
3133
#include <variant>
3234
#include <vector>
@@ -279,6 +281,7 @@ void test_clone()
279281
void test_get_config_count()
280282
{
281283
lots_of_types_bench bench;
284+
bench.set_devices(nvbench::device_manager::get().get_devices());
282285
bench.set_type_axes_names({"Integer", "Float", "Other"});
283286
bench.get_axes().get_type_axis(0).set_active_inputs({"I16", "I32"}); // 2, 2
284287
bench.get_axes().get_type_axis(1).set_active_inputs({"F32", "F64"}); // 2, 4
@@ -288,9 +291,13 @@ void test_get_config_count()
288291
bench.add_string_axis("baz", {"str", "ing"}); // 2, 72
289292
bench.add_string_axis("baz", {"single"}); // 1, 72
290293

291-
auto const num_devices = bench.get_devices().size();
294+
auto const num_devices = std::max(std::size_t(1), bench.get_devices().size());
292295

293296
ASSERT_MSG(bench.get_config_count() == 72 * num_devices, "Got {}", bench.get_config_count());
297+
298+
// Check that zero devices (e.g. CPU-only) is the same as a single device:
299+
bench.set_devices(std::vector<int>{});
300+
ASSERT_MSG(bench.get_config_count() == 72, "Got {}", bench.get_config_count());
294301
}
295302

296303
int main()

0 commit comments

Comments
 (0)