Skip to content

Commit 003c1f1

Browse files
committed
Create stream in state.get_cuda_stream if needed.
This simplifies usage in the common case where the caller just want an expected stream. The const method `get_cuda_stream_optional` still returns the optional if this behavior is undesired.
1 parent e6df734 commit 003c1f1

File tree

5 files changed

+17
-23
lines changed

5 files changed

+17
-23
lines changed

nvbench/detail/measure_cold.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,7 @@ namespace nvbench::detail
3737

3838
measure_cold_base::measure_cold_base(state &exec_state)
3939
: m_state{exec_state}
40-
, m_launch{nvbench::launch([this]() -> decltype(auto) {
41-
if (!m_state.get_cuda_stream().has_value())
42-
{
43-
m_state.set_cuda_stream(nvbench::cuda_stream{m_state.get_device()});
44-
}
45-
return m_state.get_cuda_stream().value();
46-
}())}
40+
, m_launch{exec_state.get_cuda_stream()}
4741
, m_criterion_params{exec_state.get_criterion_params()}
4842
, m_stopping_criterion{nvbench::criterion_manager::get().get_criterion(
4943
exec_state.get_stopping_criterion())}

nvbench/detail/measure_cupti.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,7 @@ measure_cupti_base::measure_cupti_base(state &exec_state)
165165
// (formatter doesn't handle `try :` very well...)
166166
try
167167
: m_state{exec_state}
168-
, m_launch{[this]() -> decltype(auto) {
169-
if (!m_state.get_cuda_stream().has_value())
170-
{
171-
m_state.set_cuda_stream(nvbench::cuda_stream{m_state.get_device()});
172-
}
173-
return m_state.get_cuda_stream().value();
174-
}()}
168+
, m_launch{exec_state.get_cuda_stream()}
175169
, m_cupti{*m_state.get_device(), add_metrics(m_state)}
176170
{}
177171
// clang-format on

nvbench/detail/measure_hot.cu

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,7 @@ namespace nvbench::detail
3636

3737
measure_hot_base::measure_hot_base(state &exec_state)
3838
: m_state{exec_state}
39-
, m_launch{nvbench::launch([this]() -> decltype(auto) {
40-
if (!m_state.get_cuda_stream().has_value())
41-
{
42-
m_state.set_cuda_stream(nvbench::cuda_stream{m_state.get_device()});
43-
}
44-
return m_state.get_cuda_stream().value();
45-
}())}
39+
, m_launch{exec_state.get_cuda_stream()}
4640
, m_min_samples{exec_state.get_min_samples()}
4741
, m_min_time{exec_state.get_min_time()}
4842
, m_skip_time{exec_state.get_skip_time()}

nvbench/state.cuh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,18 @@ struct state
6464
state &operator=(const state &) = delete;
6565
state &operator=(state &&) = default;
6666

67-
[[nodiscard]] const std::optional<nvbench::cuda_stream> &get_cuda_stream() const
67+
/// If a stream exists, return that. Otherwise, create a new stream using the current
68+
/// device (or the current device if none is set), save it, and return it.
69+
/// @sa get_cuda_stream_optional
70+
[[nodiscard]] nvbench::cuda_stream &get_cuda_stream()
71+
{
72+
if (!m_cuda_stream.has_value())
73+
{
74+
m_cuda_stream = nvbench::cuda_stream{m_device};
75+
}
76+
return m_cuda_stream.value();
77+
}
78+
[[nodiscard]] const std::optional<nvbench::cuda_stream> &get_cuda_stream_optional() const
6879
{
6980
return m_cuda_stream;
7081
}

testing/state.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ void test_streams()
5656
state_tester state{bench};
5757

5858
// Confirm that the stream hasn't been initialized yet
59-
ASSERT(!state.get_cuda_stream().has_value());
59+
ASSERT(!state.get_cuda_stream_optional().has_value());
6060

6161
// Test non-owning stream
6262
cudaStream_t default_stream = 0;
6363
state.set_cuda_stream(nvbench::cuda_stream{default_stream, false});
64+
ASSERT(state.get_cuda_stream_optional() == default_stream);
6465
ASSERT(state.get_cuda_stream() == default_stream);
6566

6667
// Test owning stream

0 commit comments

Comments
 (0)