Skip to content

Commit 0064436

Browse files
facebook-github-bot authored and committed
Add kernel execution timing to the KernelLauncher class (pytorch#4201)
Summary: Pull Request resolved: pytorch#4201 - Add kernel execution timing to the `KernelLauncher` class Reviewed By: jiawenliu64 Differential Revision: D75382325 fbshipit-source-id: 216a5d57d8f0410e5d58ffddcf509255057a5d50
1 parent 3a8ab50 commit 0064436

File tree

3 files changed

+121
-28
lines changed

3 files changed

+121
-28
lines changed

fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "fbgemm_gpu/utils/device_cache_flusher.cuh"
2020
#include "fbgemm_gpu/utils/host_device_buffer_pair.cuh"
21+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
2122
#include "fbgemm_gpu/utils/stochastic_rounding.cuh"
2223

2324
namespace fbgemm_gpu {
@@ -165,31 +166,14 @@ void time_kernel_run(
165166
Args&&... args) {
166167
std::cout << "[" << description << "] starting kernel run ..." << std::endl;
167168

168-
// Create CUDA events to time the kernel
169-
cudaEvent_t start, stop;
170-
cudaEventCreate(&start);
171-
cudaEventCreate(&stop);
172-
173-
// Execute the kernel, while recording the start and end times
174-
cudaEventRecord(start);
175-
kernel<<<grid, block>>>(std::forward<Args>(args)...);
176-
cudaEventRecord(stop);
177-
178-
// Synchronize to ensure that the kernel has completed
179-
C10_CUDA_KERNEL_LAUNCH_CHECK();
180-
cudaEventSynchronize(stop);
181-
182-
// Check for kernel execution errors
183-
const auto e = cudaGetLastError();
184-
if (e != cudaSuccess) {
185-
std::cout << "[" << description
186-
<< "] CUDA Failure: " << cudaGetErrorString(e) << std::endl;
187-
std::exit(-1);
188-
}
189-
190-
// Calculate the elapsed time in milliseconds
191-
float milliseconds = 0;
192-
cudaEventElapsedTime(&milliseconds, start, stop);
169+
const auto kernel_ = kernel;
170+
const auto milliseconds = FBGEMM_TIME_KERNEL_RUN(
171+
kernel_,
172+
grid,
173+
block,
174+
0,
175+
at::cuda::getCurrentCUDAStream(),
176+
std::forward<Args>(args)...);
193177

194178
std::cout << "[" << description << "] " << milliseconds << " ms\n"
195179
<< std::endl;
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ATen/ATen.h>
12+
#include <ATen/cuda/CUDAContext.h>
13+
#include <c10/cuda/CUDAStream.h>
14+
15+
#include <stdexcept>
16+
17+
namespace fbgemm_gpu::utils {
18+
19+
// Times work submitted to a single CUDA stream using a pair of CUDA events.
//
// Usage: call start() immediately before the kernel launch, stop() right
// after, then elapsedMillis() to retrieve the measured time.  Each instance
// is single-use: start() and stop() may each be called exactly once, and the
// calls are order-enforced (violations throw std::logic_error).
//
// NOTE(review): timing is scoped to the stream passed at construction; work
// launched on other streams is not measured.  Instances are not thread-safe
// (started_/stopped_ are unsynchronized) — confirm single-threaded use.
class KernelExecutionTimer {
 public:
  explicit KernelExecutionTimer(const c10::cuda::CUDAStream stream)
      : stream_(stream) {
    C10_CUDA_CHECK(cudaEventCreate(&start_));
    // If creating the second event fails, the constructor throws and the
    // destructor will never run for this partially-constructed object, so
    // release the first event here to avoid leaking it.
    const auto err = cudaEventCreate(&stop_);
    if (err != cudaSuccess) {
      cudaEventDestroy(start_);
      C10_CUDA_CHECK(err);
    }
  }

  // Non-copyable and non-movable: each instance owns its two events 1:1.
  KernelExecutionTimer(const KernelExecutionTimer&) = delete;
  KernelExecutionTimer& operator=(const KernelExecutionTimer&) = delete;
  KernelExecutionTimer(KernelExecutionTimer&&) = delete;
  KernelExecutionTimer& operator=(KernelExecutionTimer&&) = delete;

  ~KernelExecutionTimer() {
    // Destructors must not throw (throwing during stack unwinding calls
    // std::terminate), so warn on failure instead of using C10_CUDA_CHECK.
    C10_CUDA_CHECK_WARN(cudaEventDestroy(start_));
    C10_CUDA_CHECK_WARN(cudaEventDestroy(stop_));
  }

  // Records the start event on the stream.
  // Throws std::logic_error if called more than once.
  void start() {
    if (started_) {
      throw std::logic_error("Cannot call start() more than once.");
    }
    C10_CUDA_CHECK(cudaEventRecord(start_, stream_));
    started_ = true;
  }

  // Records the stop event on the stream.
  // Throws std::logic_error if start() has not been called yet, or if
  // stop() has already been called.
  void stop() {
    if (!started_) {
      throw std::logic_error("Must call start() before stop().");
    }
    if (stopped_) {
      throw std::logic_error("Cannot call stop() more than once.");
    }
    C10_CUDA_CHECK(cudaEventRecord(stop_, stream_));
    stopped_ = true;
  }

  // Blocks until the stop event has completed, then returns the elapsed
  // time between the two recorded events in milliseconds.
  // Throws std::logic_error if stop() has not been called.
  float elapsedMillis() const {
    if (!stopped_) {
      throw std::logic_error(
          "Must call stop() before retrieving elapsed time.");
    }
    float milliseconds = 0;
    C10_CUDA_CHECK(cudaEventSynchronize(stop_)); // Ensure timing is complete
    C10_CUDA_CHECK(cudaEventElapsedTime(&milliseconds, start_, stop_));
    return milliseconds;
  }

 private:
  cudaEvent_t start_;
  cudaEvent_t stop_;
  const c10::cuda::CUDAStream stream_;
  bool started_ = false;
  bool stopped_ = false;
};
74+
75+
} // namespace fbgemm_gpu::utils

fbgemm_gpu/include/fbgemm_gpu/utils/kernel_launcher.cuh

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
#include <c10/cuda/CUDAStream.h>
1414

1515
#include "fbgemm_gpu/utils/device_properties.cuh"
16+
#include "fbgemm_gpu/utils/kernel_execution_timer.cuh"
1617
#include "fbgemm_gpu/utils/source_context.h"
1718
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
1819

20+
#include <memory>
1921
#include <type_traits>
2022

2123
namespace fbgemm_gpu::utils {
@@ -91,7 +93,8 @@ decltype(auto) check_kernel_arg(const SourceContext& context, T&& arg) {
9193
template <
9294
bool EnableDSA = false,
9395
bool EnableBarrierIsolation = false,
94-
bool EnableNaNChecks = false>
96+
bool EnableNaNChecks = false,
97+
bool EnableExecutionTimer = false>
9598
struct KernelLauncher {
9699
const SourceContext context;
97100

@@ -263,17 +266,19 @@ struct KernelLauncher {
263266
}
264267

265268
template <typename KernelFunc, typename... Args>
266-
inline void launch_kernel(
269+
inline auto launch_kernel(
267270
const KernelFunc& kernel,
268271
const dim3 grid,
269272
const dim3 block,
270273
const size_t shared_mem_per_block,
271274
const c10::cuda::CUDAStream stream,
272-
Args&&... args) const {
275+
Args&&... args) const
276+
-> std::conditional_t<EnableExecutionTimer, float, void> {
273277
// Fetch device properties from the stream information
274278
const auto device = stream.device_index();
275279
const auto properties = *at::cuda::getDeviceProperties(device);
276280
const auto streamId = stream.id();
281+
[[maybe_unused]] std::unique_ptr<KernelExecutionTimer> timer = nullptr;
277282

278283
// Check that the grid sizes are within the range per the device associated
279284
// with the compute stream
@@ -305,6 +310,13 @@ struct KernelLauncher {
305310
cudaDeviceSynchronize();
306311
}
307312

313+
// If execution timer is enabled, initialize and start the CUDAEvents-based
314+
// timer prior to kernel launch
315+
if constexpr (EnableExecutionTimer) {
316+
timer = std::make_unique<KernelExecutionTimer>(stream);
317+
timer->start();
318+
}
319+
308320
if constexpr (EnableDSA) {
309321
// This launch code here is essentially the same as the contents of
310322
// TORCH_USE_CUDA_DSA macro, but with the addition of kernel argument
@@ -332,6 +344,11 @@ struct KernelLauncher {
332344
transform_kernel_arg(context, std::forward<Args>(args))...);
333345
}
334346

347+
// If execution timer is enabled, stop the CUDAEvents-based timer
348+
if constexpr (EnableExecutionTimer) {
349+
timer->stop();
350+
}
351+
335352
// If barrier isolation is enabled, synchronize the stream again to wait for
336353
// kernel execution to complete
337354
if constexpr (EnableBarrierIsolation) {
@@ -350,6 +367,11 @@ struct KernelLauncher {
350367
(check_kernel_arg(context.withSummary(summary), std::forward<Args>(args)),
351368
...);
352369
}
370+
371+
// If execution timer is enabled, return the elapsed time in milliseconds
372+
if constexpr (EnableExecutionTimer) {
373+
return timer->elapsedMillis();
374+
}
353375
}
354376
};
355377

@@ -420,3 +442,15 @@ struct KernelLauncher {
420442
location, #KERNEL, _FKL_TFILE_) \
421443
.launch_kernel(kernel, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__); \
422444
}())
445+
446+
// Launches KERNEL through a KernelLauncher configured with the execution
// timer enabled (last template argument = true), and yields the measured
// kernel execution time in milliseconds as the expression's value.
#define FBGEMM_TIME_KERNEL_RUN(KERNEL, GRID, BLOCK, SMEM, STREAM, ...)      \
  ([&] {                                                                    \
    using source_location = fbgemm_gpu::utils::source_location;             \
    constexpr auto location = source_location::current();                   \
    decltype(KERNEL)& kernel_ref = KERNEL;                                  \
                                                                            \
    return fbgemm_gpu::utils::                                              \
        KernelLauncher<false, _FKL_BLOCKING_, _FKL_TENSORCHECK_, true>(     \
            location, #KERNEL, _FKL_TFILE_)                                 \
            .launch_kernel(                                                 \
                kernel_ref, GRID, BLOCK, SMEM, STREAM, __VA_ARGS__);        \
  }())

0 commit comments

Comments (0)