@@ -54,8 +54,8 @@ struct measure_cold_base
5454 measure_cold_base &operator =(measure_cold_base &&) = delete ;
5555
5656protected:
57- template <bool use_blocking_kernel>
5857 struct kernel_launch_timer ;
58+ friend struct kernel_launch_timer ;
5959
6060 void check ();
6161 void initialize ();
@@ -89,8 +89,8 @@ protected:
8989 nvbench::criterion_params m_criterion_params;
9090 nvbench::stopping_criterion_base& m_stopping_criterion;
9191
92+ bool m_disable_blocking_kernel{false };
9293 bool m_run_once{false };
93- bool m_no_block{false };
9494
9595 nvbench::int64_t m_min_samples{};
9696
@@ -108,23 +108,23 @@ protected:
108108 bool m_max_time_exceeded{};
109109};
110110
111- template <bool use_blocking_kernel>
112111struct measure_cold_base ::kernel_launch_timer
113112{
114113 kernel_launch_timer (measure_cold_base &measure)
115114 : m_measure{measure}
115+ , m_disable_blocking_kernel{measure.m_disable_blocking_kernel }
116116 {}
117117
118118 __forceinline__ void start ()
119119 {
120120 m_measure.flush_device_l2 ();
121121 m_measure.sync_stream ();
122- if constexpr (use_blocking_kernel )
122+ if (!m_disable_blocking_kernel )
123123 {
124124 m_measure.block_stream ();
125125 }
126126 m_measure.m_cuda_timer .start (m_measure.m_launch .get_stream ());
127- if constexpr (!use_blocking_kernel )
127+ if (m_disable_blocking_kernel )
128128 {
129129 m_measure.m_cpu_timer .start ();
130130 }
@@ -133,7 +133,7 @@ struct measure_cold_base::kernel_launch_timer
133133 __forceinline__ void stop ()
134134 {
135135 m_measure.m_cuda_timer .stop (m_measure.m_launch .get_stream ());
136- if constexpr (use_blocking_kernel )
136+ if (!m_disable_blocking_kernel )
137137 {
138138 m_measure.m_cpu_timer .start ();
139139 m_measure.unblock_stream ();
@@ -144,9 +144,10 @@ struct measure_cold_base::kernel_launch_timer
144144
145145private:
146146 measure_cold_base &m_measure;
147+ bool m_disable_blocking_kernel;
147148};
148149
149- template <typename KernelLauncher, bool use_blocking_kernel >
150+ template <typename KernelLauncher>
150151struct measure_cold : public measure_cold_base
151152{
152153 measure_cold (nvbench::state &state, KernelLauncher &kernel_launcher)
@@ -177,15 +178,15 @@ private:
177178 return ;
178179 }
179180
180- kernel_launch_timer<use_blocking_kernel> timer (*this );
181+ kernel_launch_timer timer (*this );
181182
182183 this ->launch_kernel (timer);
183184 this ->check_skip_time (m_cuda_timer.get_duration ());
184185 }
185186
186187 void run_trials ()
187188 {
188- kernel_launch_timer<use_blocking_kernel> timer (*this );
189+ kernel_launch_timer timer (*this );
189190 do
190191 {
191192 this ->launch_kernel (timer);
0 commit comments