|
25 | 25 | #include <thread> |
26 | 26 |
|
27 | 27 | #include "gtest/gtest.h" |
| 28 | +#include "paddle/phi/api/include/context_pool.h" |
| 29 | +#include "paddle/phi/backends/gpu/gpu_context.h" |
28 | 30 |
|
29 | 31 | #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) |
30 | 32 | namespace { |
@@ -167,25 +169,6 @@ TEST(StreamTest, QueryCudaStreamNotReadyReturnsFalse) { |
167 | 169 | EXPECT_NO_THROW(s.synchronize()); |
168 | 170 | } |
169 | 171 |
|
170 | | -TEST(StreamTest, QueryCudaStreamInvalidHandleThrows) { |
171 | | - if (!at::cuda::is_available()) { |
172 | | - return; |
173 | | - } |
174 | | - |
175 | | - auto device_index = c10::cuda::getCurrentCUDAStream().device_index(); |
176 | | -#ifdef PADDLE_WITH_HIP |
177 | | - hipStream_t raw_stream = nullptr; |
178 | | -#else |
179 | | - cudaStream_t raw_stream = nullptr; |
180 | | -#endif |
181 | | - ASSERT_NO_THROW(CreateRawStream(&raw_stream)); |
182 | | - |
183 | | - auto cuda_stream = c10::cuda::getStreamFromExternal(raw_stream, device_index); |
184 | | - ASSERT_NO_THROW(DestroyRawStream(raw_stream)); |
185 | | - |
186 | | - EXPECT_THROW(cuda_stream.query(), std::exception); |
187 | | - ClearLastStreamError(); |
188 | | -} |
189 | 172 | #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP |
190 | 173 |
|
191 | 174 | // ==================== synchronize ==================== |
@@ -255,30 +238,36 @@ TEST(CUDAStreamTest, GetStreamFromPoolBoolOverloadPreservesHighPriority) { |
255 | 238 | EXPECT_NE(high_priority, low_priority); |
256 | 239 | } |
257 | 240 |
|
258 | | -// After setCurrentCUDAStream redirects the per-thread current stream, |
| 241 | +// After setCurrentCUDAStream redirects the current stream, |
259 | 242 | // getDefaultCUDAStream must still return the null stream. |
260 | 243 | TEST(CUDAStreamTest, DefaultStreamUnaffectedBySetCurrentCUDAStream) { |
261 | 244 | if (!at::cuda::is_available()) { |
262 | 245 | return; |
263 | 246 | } |
264 | | - // Snapshot the per-thread current stream before we touch it so we can |
| 247 | + // Snapshot the current stream before we touch it so we can |
265 | 248 | // restore it afterward and avoid polluting subsequent tests. |
266 | 249 | auto original_stream = c10::cuda::getCurrentCUDAStream(); |
267 | 250 |
|
268 | 251 | // Obtain a non-default stream from the pool. |
269 | 252 | auto pool_stream = c10::cuda::getStreamFromPool(/*isHighPriority=*/false); |
270 | 253 |
|
271 | | - // Redirect the per-thread current stream. |
| 254 | + // Redirect the current stream. |
272 | 255 | c10::cuda::setCurrentCUDAStream(pool_stream); |
273 | 256 |
|
274 | 257 | auto default_stream = c10::cuda::getDefaultCUDAStream(); |
275 | 258 | auto current_stream = c10::cuda::getCurrentCUDAStream(); |
| 259 | + auto place = phi::GPUPlace(current_stream.device_index()); |
276 | 260 |
|
277 | 261 | // Default stream is still null; current stream has changed. |
278 | 262 | EXPECT_EQ(default_stream.id(), static_cast<c10::StreamId>(0)); |
279 | 263 | EXPECT_NE(default_stream, current_stream); |
| 264 | + EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(), |
| 265 | + current_stream.stream()); |
280 | 266 |
|
281 | | - // Restore the original per-thread current stream. |
| 267 | + // Restore the original current stream. |
282 | 268 | c10::cuda::setCurrentCUDAStream(original_stream); |
| 269 | + EXPECT_EQ(paddle::GetCurrentCUDAStream(place)->raw_stream(), |
| 270 | + original_stream.stream()); |
283 | 271 | } |
| 272 | + |
284 | 273 | #endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP |
0 commit comments