
[Cpp API Compatibility] Adapt cuda native APIs to hip native APIs for DCU support #78595

Merged

SigureMo merged 18 commits into PaddlePaddle:develop from youge325:cDCU on Apr 10, 2026

Conversation

@youge325 (Contributor) commented Apr 7, 2026

PR Category

Execute Infrastructure

PR Types

New features

Description

Add support for DCU devices. The unit tests will be included in the build after #78580 is merged.

Whether precision changes are introduced

Copilot AI review requested due to automatic review settings April 7, 2026 03:45
@paddle-bot
paddle-bot bot commented Apr 7, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 7, 2026
Copilot AI (Contributor) left a comment

Pull request overview

This PR extends Paddle’s PyTorch-compatible C++ “compat” layer to support DCU/ROCm builds by adapting CUDA-native APIs (streams, events, BLAS, etc.) to their HIP-native equivalents under PADDLE_WITH_HIP.

Changes:

  • Added HIP-aware stream/event interop in c10 compat APIs (e.g., hipStream_t / hipEvent_t conversions and priority queries).
  • Updated ATen compat APIs to accept HIP stream/event types (e.g., Tensor::record_stream(hipStream_t), CUDAEvent HIP support).
  • Added initial HIP branching for BLAS GEMM and CUDA data-type mapping.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
paddle/phi/api/include/compat/c10/cuda/CUDAStream.h Add hipStream_t conversions and HIP priority/range queries.
paddle/phi/api/include/compat/c10/cuda/CUDAStream.cpp Implement HIP stream pool creation, TLS current stream, and external-stream wrapping.
paddle/phi/api/include/compat/c10/cuda/CUDAException.h Add HIP include + HIP-aware C10_CUDA_CHECK.
paddle/phi/api/include/compat/c10/core/Stream.cpp Use HIP stream query/synchronize when PADDLE_WITH_HIP.
paddle/phi/api/include/compat/c10/core/Event.h Add HIP raw-stream record and HIP event accessor type.
paddle/phi/api/include/compat/ATen/ops/record_stream.h Add Tensor::record_stream(hipStream_t) overload under HIP.
paddle/phi/api/include/compat/ATen/cuda/CUDAEvent.h Add HIP event lifecycle/record/wait/sync/elapsed-time support.
paddle/phi/api/include/compat/ATen/cuda/CUDADataType.h Add HIP includes and HIP enum/value mappings for data types.
paddle/phi/api/include/compat/ATen/cuda/CUDAContextLight.h Adjust includes and CUDA/HIP-conditional type aliases and declarations.
paddle/phi/api/include/compat/ATen/cuda/CUDABlas.cpp Add HIP branch for GEMM intended to use HIP BLAS equivalents.
paddle/phi/api/include/compat/ATen/core/TensorBody.h Include HIP runtime headers when PADDLE_WITH_HIP.


Comment on lines +138 to +147
#if defined(PADDLE_WITH_HIP)
case c10::ScalarType::Float8_e4m3fnuz:
return HIP_R_8F_E4M3_FNUZ;
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#elif !defined(USE_ROCM) || ROCM_VERSION >= 60300
case c10::ScalarType::Float8_e4m3fn:
return CUDA_R_8F_E4M3;
case c10::ScalarType::Float8_e5m2:
return CUDA_R_8F_E5M2;

Copilot AI Apr 7, 2026


In ScalarTypeToCudaDataType(), the HIP branch now only handles Float8_e4m3fnuz/Float8_e5m2fnuz, but not Float8_e4m3fn/Float8_e5m2. These scalar types exist in c10::ScalarType, and on HIP builds this function will hit TORCH_INTERNAL_ASSERT(false, ...) for them. Consider restoring handling for Float8_e4m3fn/Float8_e5m2 in the HIP branch (mapping to HIP_R_8F_E4M3 / HIP_R_8F_E5M2) to avoid a runtime assert regression.

#define CUDA_R_8F_E4M3 HIP_R_8F_E4M3
#define CUDA_R_8F_E5M2 HIP_R_8F_E5M2
#elif defined(PADDLE_WITH_CUDA)
using cudaDataType = cudaDataType;

Copilot AI Apr 7, 2026


The CUDA branch uses using cudaDataType = cudaDataType;, which relies on unqualified lookup to find the global ::cudaDataType and is easy to misread as a self-referential alias. Qualifying the RHS (e.g., ::cudaDataType) or removing the alias entirely on the CUDA path would make the intent clearer and avoid confusion for tools/readers.

Suggested change
using cudaDataType = cudaDataType;
using cudaDataType = ::cudaDataType;

@youge325 (Contributor, Author) commented Apr 7, 2026

/re-run all-failed

@youge325 (Contributor, Author) commented Apr 7, 2026

Is something wrong with the dcu-build logs? If the run fails, the logs get deleted and cannot be viewed.

@youge325 (Contributor, Author) commented Apr 7, 2026

Huh, there is actually a paddle_test. Will it solve the problem in the DCU build where WITH_SHARED_PHI=OFF prevents test files from linking?

@youge325 (Contributor, Author) commented Apr 7, 2026

No luck; I still had to modify CMakeLists.txt. This time I added a link to phi in the WITH_SHARED_PHI=OFF branch.

@SigureMo (Member) commented Apr 7, 2026

> Huh, there is actually a paddle_test. Will it solve the problem in the DCU build where WITH_SHARED_PHI=OFF prevents test files from linking?

As I recall, paddle_test had problems, which is why it was changed to cc_test in #75012.

@ShigureNyako (Contributor) left a comment

I agree with the overall direction, and the HIP stream / event / context adaptation largely follows the existing abstractions of the compat layer. However, I cannot approve this round yet; there are two blockers:

  1. In paddle/phi/api/include/compat/ATen/cuda/CUDADataType.h, the HIP branch misses the Float8_e4m3fn / Float8_e5m2 mappings. These two enums now fall straight through to default, which is a clear compatibility regression.
  2. test/cpp/compat/CMakeLists.txt now enables the whole compat test suite under ROCm/DCU, but the required Linux-DCU / Test job is already failing because these cases call HIP APIs directly when no device is visible. Either hold off on enabling WITH_ROCM wholesale, or first add runtime guards / GTEST_SKIP() to the cases that depend on a real device before enabling them.

Separately, I looked at PR-CI-SOT / Build and Test; that failure looks like numpy>=2.3.5 failing to install in a Python 3.14 environment, and for now I see no direct relation to this change.

@@ -1,4 +1,4 @@
if(NOT WIN32 AND NOT WITH_ROCM)
if(NOT WIN32)
Contributor

Removing NOT WITH_ROCM here pulls the entire compat C++ test suite into ROCm/DCU CI, including the existing ATen_CUDAContext_test, c10_Event_test, c10_Stream_test, and others.

The required Linux-DCU / Test job is already failing because these cases call HIP APIs directly when no device is visible; the logs show multiple occurrences of Hip error(100), no ROCm-capable device is detected. If this round is only meant to fill in the compat interface adaptation, I suggest not enabling WITH_ROCM wholesale yet; at minimum, the cases that depend on a real device should get a runtime guard / GTEST_SKIP() before being enabled in DCU CI.

Member

Huh? The main goal of this PR is precisely to test DCU; it is not about filling in compat interfaces.

Contributor

Understood; the premise in my previous comment was inaccurate. Since the goal of this PR is to bring up DCU testing, removing NOT WITH_ROCM here is fine in itself.

What blocks me now is that this batch of cases still does not pass the required Linux-DCU / Test job. Looking at that job's logs, the failures are concentrated in the newly enabled compat tests, for example:

  • CUDAContextLightTest.IsAvailable
  • EventTest.CudaEventLazyCreateAndRecord
  • CUDAStreamTest.DefaultStreamIsNullStream

All report the same class of error: Hip error(100), no ROCm-capable device is detected.

So what I want to say is: the goal can indeed be testing DCU, but before merging, this path has to be made to work. Either these cases get runtime guards / GTEST_SKIP() based on the runtime environment, or the device environment of this DCU job is verified and fixed so that it can actually execute these tests. In its current state, I am treating this as blocking.

return CUDA_R_8F_E5M2;
#endif
#if defined(USE_ROCM)
#if defined(PADDLE_WITH_HIP)
Contributor

This changes the HIP branch to handle only Float8_e4m3fnuz / Float8_e5m2fnuz, but Float8_e4m3fn / Float8_e5m2 from c10::ScalarType used to be mapped here as well. Under HIP compilation, these two enums now fall straight through to default and trigger TORCH_INTERNAL_ASSERT(false, ...), which amounts to a compatibility regression.

I suggest keeping these two cases in the HIP branch as well, still mapping them to HIP_R_8F_E4M3 / HIP_R_8F_E5M2.

@codecov-commenter

codecov-commenter commented Apr 7, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@2c4a369). Learn more about missing BASE report.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #78595   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         1           
  Lines              ?         8           
  Branches           ?         0           
===========================================
  Hits               ?         8           
  Misses             ?         0           
  Partials           ?         0           

☔ View full report in Codecov by Sentry.

@youge325 (Contributor, Author) commented Apr 8, 2026

/re-run all-failed

@ShigureNyako (Contributor) left a comment

Adding one more blocker on the current head: this round of fix DCU-test restores the Float8 mappings and starts adding guards to the DCU-related cases, but the newly introduced HasVisibleCUDADevice() still depends on c10::cuda::device_count(). Given the compat layer's current implementation, that path itself may throw first in the "no visible ROCm device" scenario, so it still needs to be confirmed whether this fix actually covers the Linux-DCU / Test failure path.

I will keep waiting for this round of CI; if Linux-DCU / Test passes, this point can be re-checked together.


namespace {

bool HasVisibleCUDADevice() { return c10::cuda::device_count() > 0; }
Contributor

This guard is still not robust.

HasVisibleCUDADevice() calls c10::cuda::device_count(), but in the compat layer's current implementation that function forwards directly to phi::backends::gpu::GetGPUDeviceCount() (paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.cpp:23-26). In the previous Linux-DCU / Test run, the Hip error(100), no ROCm-capable device is detected reported by EventTest.DeviceCount / StreamTest.DeviceCount came from exactly this path.

In other words, in the "no visible device" scenario this helper may well throw before it ever returns false; CUDAContextLightTest.IsAvailable in the same file also still goes through this path directly.

I suggest switching to a probe that cannot throw, or at least catching inside the helper and returning false; otherwise this round of fix DCU-test may still not cover the original DCU test failures.

Contributor Author

When the DCU run detects no device, the tests error out directly. Here we align with cuda_info: do not raise an error, just print a log line.

Contributor Author

Why would the DCU tests run on a machine without a DCU?

Contributor

Just saw your two comments; replying to them first.

I raised this point not because I assume "DCU tests should run on machines without a DCU", but because in the previous round, the logs of the required Linux-DCU / Test job actually went down this no-visible-device path and reported no ROCm-capable device is detected. As a reviewer, I can only block on the path that CI actually took at the time.

Your latest commit changes paddle/phi/backends/gpu/rocm/rocm_info.cc:70-74 so that a hipGetDeviceCount failure goes from a hard error to VLOG + count=0. I agree with that, and it is consistent with your earlier point about aligning with cuda_info. In other words, the device-count exception risk I raised earlier in this thread has now been directly addressed by this commit.

I will stop dwelling on "why that CI run landed on a machine without a card". Once this round of Linux-DCU / Test actually reaches the test stage and no longer fails on this path, I will update the conclusion of this review thread.

Contributor Author

So that all tests can run on DCU, all CUDA-related tests are now guarded with at::cuda::is_available().

Contributor

Got it; I just caught up on this comment.

If this batch of CUDA-related tests is now uniformly guarded with at::cuda::is_available(), that approach is more on target than the device_count() path I was worried about, and the coverage is more complete. With that, my concerns in this thread basically converge to one thing: confirming that this round of CI, especially Linux-DCU / Test, really no longer fails on the "no visible device" path.

I was a bit slow to follow up earlier, mainly because I was waiting for this round's DCU job to actually produce test results rather than drawing conclusions from the code alone. I am syncing this understanding to the thread now and will decide whether to update the review conclusion once the CI results are in.

@youge325 (Contributor, Author) commented Apr 8, 2026

/re-run all-failed

@ShigureNyako (Contributor) left a comment

LGTMeow 🐾

I re-checked the latest CI; both blockers I previously raised have now converged:

  • Linux-DCU / Test finished as SUCCESS, and the no ROCm-capable device is detected path no longer appears as a failure cause in the logs.
  • test_nn_init_function did time out on its first run, but passed on rerun; the DCU job's final verdict is success.
  • The Float8_e4m3fn / Float8_e5m2 mappings under HIP have also been restored.

So I am changing my earlier "changes requested" to an approval. The remaining Check approval failure is just waiting for reviewer approval, not a code problem.

@ShigureNyako (Contributor) commented

One addition: when I approved just now, Coverage test had not yet produced its final result. That CI run has now finished, and the failure cause is clear.

  • The "Determine whether the coverage rate reaches 90%" step in Coverage test reports: expected >= 90.0 %, actual 75.0 %.
  • The Codecov comment maps this to paddle/phi/api/include/compat/c10/core/Stream.cpp, where 2 diff lines are still uncovered.
  • The unit test failures in that job all turned green after rerun, so what remains is the test-coverage gate, not a regression of the DCU compatibility path discussed earlier.

In other words, the functional/compatibility blockers I raised have converged, but this PR still needs to close this coverage gap before CI fully passes.

@youge325 (Contributor, Author) commented Apr 9, 2026

/re-run all-failed

1 similar comment
@youge325 (Contributor, Author) commented Apr 9, 2026

/re-run all-failed

@SigureMo (Member) left a comment

LGTMeow 🐾

Feel free to nudge me tomorrow about finding other folks to review; I might forget.

Comment on lines +70 to +75
status = hipGetDeviceCount(&count);
if (status != hipSuccess) {
VLOG(2) << "You have gpu driver and rocm installed, but the machine does "
"not have any visible gpu card.";
count = 0;
}
Member

@zrr1999 The CUDA side was changed long ago in #55335; this looks like it simply aligns with CUDA.

Comment on lines +251 to +259
else()
# For static phi builds (e.g., DCU), link phi static libraries directly
target_link_libraries(${test_name} phi)
if(WITH_GPU OR WITH_ROCM)
target_link_libraries(${test_name} -Wl,--start-group phi_core phi_gpu
-Wl,--end-group)
else()
target_link_libraries(${test_name} phi_core)
endif()
Member

@risemeup1 When reviewing, please also take a look at whether this could have any side effects.

@SigureMo SigureMo merged commit eafa80d into PaddlePaddle:develop Apr 10, 2026
169 of 177 checks passed
@youge325 youge325 deleted the cDCU branch April 10, 2026 03:38

Labels

contributor External developers
