-
Notifications
You must be signed in to change notification settings - Fork 6k
[Cpp API Compatibility] Adapt cuda native APIs to hip native APIs for DCU support #78595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2c65fd3
ec24e81
d34c7d5
8a2b98b
96e93d4
61872e1
48e4361
b38c988
fd1d7ad
ac9d01e
4573a0d
9660018
31730cb
2a8a096
eb5317b
3d46977
6a64393
e38d958
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,11 +20,36 @@ | |
|
|
||
| #include <c10/core/ScalarType.h> | ||
|
|
||
| #if defined(PADDLE_WITH_HIP) | ||
| #include <hip/hip_runtime.h> | ||
| #include <hip/library_types.h> | ||
| #elif defined(PADDLE_WITH_CUDA) | ||
| #include <cuda.h> | ||
| #include <library_types.h> | ||
| #endif | ||
|
|
||
| namespace at::cuda { | ||
|
|
||
| #if defined(PADDLE_WITH_HIP) | ||
| using cudaDataType = hipDataType; | ||
| #define CUDA_R_16F HIP_R_16F | ||
| #define CUDA_R_32F HIP_R_32F | ||
| #define CUDA_R_64F HIP_R_64F | ||
| #define CUDA_C_16F HIP_C_16F | ||
| #define CUDA_C_32F HIP_C_32F | ||
| #define CUDA_C_64F HIP_C_64F | ||
| #define CUDA_R_8U HIP_R_8U | ||
| #define CUDA_R_8I HIP_R_8I | ||
| #define CUDA_R_32I HIP_R_32I | ||
| #define CUDA_R_16I HIP_R_16I | ||
| #define CUDA_R_64I HIP_R_64I | ||
| #define CUDA_R_16BF HIP_R_16BF | ||
| #define CUDA_R_8F_E4M3 HIP_R_8F_E4M3 | ||
| #define CUDA_R_8F_E5M2 HIP_R_8F_E5M2 | ||
| #elif defined(PADDLE_WITH_CUDA) | ||
| using cudaDataType = cudaDataType; | ||
| #endif | ||
|
|
||
| template <typename scalar_t> | ||
| cudaDataType getCudaDataType() { | ||
| static_assert(false && sizeof(scalar_t), | ||
|
|
@@ -110,17 +135,20 @@ inline cudaDataType ScalarTypeToCudaDataType( | |
| return CUDA_R_64I; | ||
| case c10::ScalarType::BFloat16: | ||
| return CUDA_R_16BF; | ||
| #if !defined(USE_ROCM) || ROCM_VERSION >= 60300 | ||
| #if defined(PADDLE_WITH_HIP) | ||
| case c10::ScalarType::Float8_e4m3fn: | ||
| return CUDA_R_8F_E4M3; | ||
| case c10::ScalarType::Float8_e5m2: | ||
| return CUDA_R_8F_E5M2; | ||
| #endif | ||
| #if defined(USE_ROCM) | ||
| case c10::ScalarType::Float8_e4m3fnuz: | ||
| return HIP_R_8F_E4M3_FNUZ; | ||
| case c10::ScalarType::Float8_e5m2fnuz: | ||
| return HIP_R_8F_E5M2_FNUZ; | ||
| #elif !defined(USE_ROCM) || ROCM_VERSION >= 60300 | ||
| case c10::ScalarType::Float8_e4m3fn: | ||
| return CUDA_R_8F_E4M3; | ||
| case c10::ScalarType::Float8_e5m2: | ||
| return CUDA_R_8F_E5M2; | ||
|
Comment on lines
+138
to
+151
|
||
| #endif | ||
| // #if (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) || | ||
| // (defined(USE_ROCM) && ROCM_VERSION >= 70000) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CUDA branch uses `using cudaDataType = cudaDataType;`, which relies on unqualified lookup to find the global `::cudaDataType` and is easy to misread as a self-referential alias. Qualifying the RHS (e.g., `::cudaDataType`) or removing the alias entirely on the CUDA path would make the intent clearer and avoid confusion for tools/readers.