-
Notifications
You must be signed in to change notification settings - Fork 6k
Expand file tree
/
Copy pathCUDAStream.cpp
More file actions
232 lines (202 loc) · 7.06 KB
/
CUDAStream.cpp
File metadata and controls
232 lines (202 loc) · 7.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <c10/cuda/CUDAStream.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <vector>
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/backends/gpu/gpu_info.h"
#endif
namespace c10::cuda {
namespace {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
constexpr int kStreamsPerPool = 32;
std::once_flag g_init_once;
c10::DeviceIndex g_num_gpus = -1;
struct DevicePools {
#ifdef PADDLE_WITH_HIP
std::vector<hipStream_t> low_priority;
std::vector<hipStream_t> high_priority;
#else
std::vector<cudaStream_t> low_priority;
std::vector<cudaStream_t> high_priority;
#endif
std::atomic<uint32_t> lp_counter{0};
std::atomic<uint32_t> hp_counter{0};
std::once_flag init_flag;
};
std::vector<std::unique_ptr<DevicePools>> g_pools;
#ifdef PADDLE_WITH_HIP
thread_local std::vector<hipStream_t> tls_current_streams;
#else
thread_local std::vector<cudaStream_t> tls_current_streams;
#endif
thread_local bool tls_streams_initialized = false;
void initGlobalState() {
std::call_once(g_init_once, []() {
g_num_gpus =
static_cast<c10::DeviceIndex>(phi::backends::gpu::GetGPUDeviceCount());
g_pools.resize(g_num_gpus);
for (auto& ptr : g_pools) {
ptr = std::make_unique<DevicePools>();
}
});
}
// Creates the kStreamsPerPool low- and high-priority non-blocking streams of
// one device, with that device made current through GPUDeviceGuard.
//
// This function runs under std::call_once(pool.init_flag, ...). If a call
// exits via an exception, call_once leaves the flag unset and the next
// getStreamFromPool() re-enters here. This assumes C10_CUDA_CHECK throws on
// error, as it does in upstream c10. To keep that retry leak-free:
//  - slots filled by an earlier, failed attempt are kept;
//  - only null slots are (re)created;
//  - a handle is stored into its slot only after the create call succeeded.
// Previously every slot was recreated unconditionally, overwriting and
// leaking the streams already created by the failed attempt.
//
// @param device_index a device index already validated by check_gpu().
void initDevicePools(c10::DeviceIndex device_index) {
  phi::backends::gpu::GPUDeviceGuard guard(device_index);
  // The runtime reports (least, greatest) priority; numerically lower values
  // mean higher priority.
  int lo_pri = 0, hi_pri = 0;
#ifdef PADDLE_WITH_HIP
  C10_CUDA_CHECK(hipDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
#else
  C10_CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
#endif
  auto& pool = *g_pools[device_index];
  // resize() value-initializes new slots to nullptr and is a no-op on retry.
  pool.low_priority.resize(kStreamsPerPool);
  pool.high_priority.resize(kStreamsPerPool);
  for (int i = 0; i < kStreamsPerPool; ++i) {
#ifdef PADDLE_WITH_HIP
    if (pool.low_priority[i] == nullptr) {
      hipStream_t stream = nullptr;
      C10_CUDA_CHECK(
          hipStreamCreateWithPriority(&stream, hipStreamNonBlocking, lo_pri));
      pool.low_priority[i] = stream;
    }
    if (pool.high_priority[i] == nullptr) {
      hipStream_t stream = nullptr;
      C10_CUDA_CHECK(
          hipStreamCreateWithPriority(&stream, hipStreamNonBlocking, hi_pri));
      pool.high_priority[i] = stream;
    }
#else
    if (pool.low_priority[i] == nullptr) {
      cudaStream_t stream = nullptr;
      C10_CUDA_CHECK(
          cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, lo_pri));
      pool.low_priority[i] = stream;
    }
    if (pool.high_priority[i] == nullptr) {
      cudaStream_t stream = nullptr;
      C10_CUDA_CHECK(
          cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, hi_pri));
      pool.high_priority[i] = stream;
    }
#endif
  }
}
// Fails a TORCH_CHECK unless `device_index` lies in [0, g_num_gpus).
// Every caller in this file runs initGlobalState() first, so g_num_gpus holds
// the cached device count by the time this is evaluated.
inline void check_gpu(c10::DeviceIndex device_index) {
// Both indices are printed through static_cast<int>. NOTE(review):
// c10::DeviceIndex is presumably a narrow integer (int8_t upstream) that
// would otherwise stream as a character; confirm against the c10 shim.
TORCH_CHECK(device_index >= 0 && device_index < g_num_gpus,
"Device index value ",
static_cast<int>(device_index),
" is out of index range [0, ",
static_cast<int>(g_num_gpus),
")");
}
// Sizes the calling thread's current-stream table on first use. It gets one
// null entry (meaning "use the default stream") per device. Must run after
// initGlobalState() so that g_num_gpus holds the device count.
inline void initTLSCurrentStreams() {
  if (tls_streams_initialized) {
    return;
  }
  tls_current_streams.resize(g_num_gpus, nullptr);
  tls_streams_initialized = true;
}
#endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} // namespace
// Wraps a raw stream handle as a CUDAStream on `device_index`. The handle's
// pointer value itself becomes the c10::StreamId, and the Stream is built
// through the c10::Stream::UNSAFE constructor. NOTE(review): the reverse
// mapping (id -> handle) lives in CUDAStream; confirm it matches.
inline CUDAStream make_cuda_stream(cudaStream_t raw,
                                   c10::DeviceIndex device_index) {
  const auto stream_id =
      static_cast<c10::StreamId>(reinterpret_cast<intptr_t>(raw));
  return CUDAStream(c10::Stream(
      c10::Stream::UNSAFE,
      c10::Device(c10::DeviceType::CUDA, device_index),
      stream_id));
}
// Returns the next stream from `device_index`'s pool.
// - A negative `priority` selects the high-priority pool; any other value
//   selects the low-priority one.
// - device_index == -1 means the device reported by GetCurrentDeviceId().
// - Each pool hands out its kStreamsPerPool streams round-robin.
// - The device's streams are created on first use.
// Without CUDA/HIP this fails a TORCH_CHECK.
CUDAStream getStreamFromPool(const int priority,
                             c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  if (device_index == -1) {
    device_index =
        static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
  }
  check_gpu(device_index);

  DevicePools& pool = *g_pools[device_index];
  std::call_once(pool.init_flag, initDevicePools, device_index);

  const bool high = priority < 0;
  auto& counter = high ? pool.hp_counter : pool.lp_counter;
  const auto& streams = high ? pool.high_priority : pool.low_priority;
  const uint32_t slot = counter.fetch_add(1) % kStreamsPerPool;
  return make_cuda_stream(streams[slot], device_index);
#else
  TORCH_CHECK(false, "getStreamFromPool is not supported without CUDA/HIP");
  return getDefaultCUDAStream(device_index);
#endif
}
// Boolean convenience overload. It forwards to the integer overload with
// -1 (high-priority pool) or 0 (low-priority pool).
CUDAStream getStreamFromPool(const bool isHighPriority,
                             c10::DeviceIndex device_index) {
  const int priority = isHighPriority ? -1 : 0;
  return getStreamFromPool(priority, device_index);
}
// Wraps a caller-provided stream handle as a CUDAStream on `device_index`.
// This function neither creates nor destroys the handle. In GPU builds the
// index is validated first.
// NOTE(review): unlike the other getters, -1 is not mapped to the current
// device here, so check_gpu() rejects it.
#ifdef PADDLE_WITH_HIP
CUDAStream getStreamFromExternal(hipStream_t ext_stream,
                                 c10::DeviceIndex device_index) {
  // PADDLE_WITH_HIP implies a GPU build, so validation always applies here.
  initGlobalState();
  check_gpu(device_index);
  return make_cuda_stream(ext_stream, device_index);
}
#else
CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
                                 c10::DeviceIndex device_index) {
  // HIP is ruled out on this branch, so only a CUDA build validates.
#ifdef PADDLE_WITH_CUDA
  initGlobalState();
  check_gpu(device_index);
#endif
  return make_cuda_stream(ext_stream, device_index);
}
#endif
// Returns the c10::Stream::DEFAULT stream of `device_index`.
// - In GPU builds, -1 is resolved via GetCurrentDeviceId() and the index is
//   then validated.
// - Without CUDA/HIP the index is used as given.
CUDAStream getDefaultCUDAStream(c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  const bool use_current_device = (device_index == -1);
  if (use_current_device) {
    const auto current = phi::backends::gpu::GetCurrentDeviceId();
    device_index = static_cast<c10::DeviceIndex>(current);
  }
  check_gpu(device_index);
#endif
  return CUDAStream(
      c10::Stream(c10::Stream::DEFAULT,
                  c10::Device(c10::DeviceType::CUDA, device_index)));
}
// Returns the calling thread's current stream for `device_index` (-1 means
// the current device). The stream is whatever setCurrentCUDAStream() last
// stored for this thread. When nothing (or a null handle) was stored, and in
// builds without CUDA/HIP, the default stream is returned.
CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  if (device_index == -1) {
    const auto current = phi::backends::gpu::GetCurrentDeviceId();
    device_index = static_cast<c10::DeviceIndex>(current);
  }
  check_gpu(device_index);
  initTLSCurrentStreams();
  const auto raw = tls_current_streams[device_index];
  if (raw != nullptr) {
    return make_cuda_stream(raw, device_index);
  }
  return getDefaultCUDAStream(device_index);
#else
  return getDefaultCUDAStream(device_index);
#endif
}
// Records `stream` as the calling thread's current stream on the stream's
// own device, after validating that device index. Other threads are
// unaffected. Without CUDA/HIP this is a no-op.
void setCurrentCUDAStream(CUDAStream stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  initGlobalState();
  const c10::DeviceIndex device_index = stream.unwrap().device_index();
  check_gpu(device_index);
  initTLSCurrentStreams();
  auto& slot = tls_current_streams[device_index];
  slot = stream.stream();
#else
  (void)stream;
#endif
}
} // namespace c10::cuda