Skip to content

Commit 453a3be

Browse files
authored
[Cpp API Compatibility] Sync c10 CUDA stream state with Paddle's GPUContext stream (#78652)
1 parent 0442435 commit 453a3be

File tree

3 files changed

+37
-46
lines changed

3 files changed

+37
-46
lines changed

paddle/phi/api/include/compat/c10/cuda/CUDAStream.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <mutex>
2020
#include <vector>
2121

22+
#include "paddle/phi/api/include/context_pool.h"
23+
#include "paddle/phi/backends/gpu/gpu_context.h"
2224
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
2325
#include "paddle/phi/backends/gpu/gpu_info.h"
2426
#endif
@@ -49,13 +51,6 @@ struct DevicePools {
4951

5052
std::vector<std::unique_ptr<DevicePools>> g_pools;
5153

52-
#ifdef PADDLE_WITH_HIP
53-
thread_local std::vector<hipStream_t> tls_current_streams;
54-
#else
55-
thread_local std::vector<cudaStream_t> tls_current_streams;
56-
#endif
57-
thread_local bool tls_streams_initialized = false;
58-
5954
void initGlobalState() {
6055
std::call_once(g_init_once, []() {
6156
g_num_gpus =
@@ -104,12 +99,25 @@ inline void check_gpu(c10::DeviceIndex device_index) {
10499
")");
105100
}
106101

107-
inline void initTLSCurrentStreams() {
108-
if (!tls_streams_initialized) {
109-
tls_current_streams.resize(g_num_gpus, nullptr);
110-
tls_streams_initialized = true;
111-
}
102+
inline phi::GPUContext* getMutableGPUContext(c10::DeviceIndex device_index) {
103+
return static_cast<phi::GPUContext*>(
104+
paddle::experimental::DeviceContextPool::Instance().GetMutable(
105+
phi::GPUPlace(device_index)));
106+
}
107+
108+
#ifdef PADDLE_WITH_HIP
109+
inline hipStream_t getPaddleCurrentStream(c10::DeviceIndex device_index) {
110+
auto* current_stream =
111+
paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index));
112+
return current_stream == nullptr ? nullptr : current_stream->raw_stream();
113+
}
114+
#else
115+
inline cudaStream_t getPaddleCurrentStream(c10::DeviceIndex device_index) {
116+
auto* current_stream =
117+
paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index));
118+
return current_stream == nullptr ? nullptr : current_stream->raw_stream();
112119
}
120+
#endif
113121

114122
#endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
115123

@@ -192,12 +200,7 @@ CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index) {
192200
static_cast<c10::DeviceIndex>(phi::backends::gpu::GetCurrentDeviceId());
193201
}
194202
check_gpu(device_index);
195-
initTLSCurrentStreams();
196-
#ifdef PADDLE_WITH_HIP
197-
hipStream_t raw = tls_current_streams[device_index];
198-
#else
199-
cudaStream_t raw = tls_current_streams[device_index];
200-
#endif
203+
auto raw = getPaddleCurrentStream(device_index);
201204
if (raw == nullptr) {
202205
return getDefaultCUDAStream(device_index);
203206
}
@@ -212,8 +215,7 @@ void setCurrentCUDAStream(CUDAStream stream) {
212215
initGlobalState();
213216
c10::DeviceIndex idx = stream.unwrap().device_index();
214217
check_gpu(idx);
215-
initTLSCurrentStreams();
216-
tls_current_streams[idx] = stream.stream();
218+
getMutableGPUContext(idx)->SetStream(stream.stream());
217219
#else
218220
(void)stream;
219221
#endif

paddle/phi/api/include/compat/c10/cuda/CUDAStream.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ CUDAStream getStreamFromExternal(cudaStream_t ext_stream,
191191
#endif
192192

193193
/**
194-
* Set the current CUDA stream for the device of the given stream in the
195-
* calling thread.
194+
* Set the current CUDA stream for the device of the given stream.
196195
*
197-
* Implements per-thread, per-device current stream semantics.
196+
* Keeps the compat c10 stream state aligned with Paddle's GPUContext so
197+
* Paddle stream guards and c10 callers observe the same current stream.
198198
*/
199199
void setCurrentCUDAStream(CUDAStream stream);
200200

test/cpp/compat/c10_Stream_test.cc

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <thread>
2626

2727
#include "gtest/gtest.h"
28+
#include "paddle/phi/api/include/context_pool.h"
29+
#include "paddle/phi/backends/gpu/gpu_context.h"
2830

2931
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
3032
namespace {
@@ -167,25 +169,6 @@ TEST(StreamTest, QueryCudaStreamNotReadyReturnsFalse) {
167169
EXPECT_NO_THROW(s.synchronize());
168170
}
169171

170-
TEST(StreamTest, QueryCudaStreamInvalidHandleThrows) {
171-
if (!at::cuda::is_available()) {
172-
return;
173-
}
174-
175-
auto device_index = c10::cuda::getCurrentCUDAStream().device_index();
176-
#ifdef PADDLE_WITH_HIP
177-
hipStream_t raw_stream = nullptr;
178-
#else
179-
cudaStream_t raw_stream = nullptr;
180-
#endif
181-
ASSERT_NO_THROW(CreateRawStream(&raw_stream));
182-
183-
auto cuda_stream = c10::cuda::getStreamFromExternal(raw_stream, device_index);
184-
ASSERT_NO_THROW(DestroyRawStream(raw_stream));
185-
186-
EXPECT_THROW(cuda_stream.query(), std::exception);
187-
ClearLastStreamError();
188-
}
189172
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
190173

191174
// ==================== synchronize ====================
@@ -255,30 +238,36 @@ TEST(CUDAStreamTest, GetStreamFromPoolBoolOverloadPreservesHighPriority) {
255238
EXPECT_NE(high_priority, low_priority);
256239
}
257240

258-
// After setCurrentCUDAStream redirects the per-thread current stream,
241+
// After setCurrentCUDAStream redirects the current stream,
259242
// getDefaultCUDAStream must still return the null stream.
260243
TEST(CUDAStreamTest, DefaultStreamUnaffectedBySetCurrentCUDAStream) {
261244
if (!at::cuda::is_available()) {
262245
return;
263246
}
264-
// Snapshot the per-thread current stream before we touch it so we can
247+
// Snapshot the current stream before we touch it so we can
265248
// restore it afterward and avoid polluting subsequent tests.
266249
auto original_stream = c10::cuda::getCurrentCUDAStream();
267250

268251
// Obtain a non-default stream from the pool.
269252
auto pool_stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false);
270253

271-
// Redirect the per-thread current stream.
254+
// Redirect the current stream.
272255
c10::cuda::setCurrentCUDAStream(pool_stream);
273256

274257
auto default_stream = c10::cuda::getDefaultCUDAStream();
275258
auto current_stream = c10::cuda::getCurrentCUDAStream();
259+
auto place = phi::GPUPlace(current_stream.device_index());
276260

277261
// Default stream is still null; current stream has changed.
278262
EXPECT_EQ(default_stream.id(), static_cast<c10::StreamId>(0));
279263
EXPECT_NE(default_stream, current_stream);
264+
EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(),
265+
current_stream.stream());
280266

281-
// Restore the original per-thread current stream.
267+
// Restore the original current stream.
282268
c10::cuda::setCurrentCUDAStream(original_stream);
269+
EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(),
270+
original_stream.stream());
283271
}
272+
284273
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP

0 commit comments

Comments
 (0)