Skip to content

Commit d132f22

Browse files
authored
HIP: add CDNA4 (gfx950) architecture support for MI350X/MI355X (ggml-org#21570)
Add AMD Instinct MI350X/MI355X (gfx950, CDNA4) support: - vendors/hip.h: Add CDNA4 preprocessor define for __gfx950__ - common.cuh: Add GGML_CUDA_CC_CDNA4 and GGML_CUDA_CC_IS_CDNA4 macros - mma.cuh: Route CDNA4 to compatible MFMA instructions: * f32 matmul: mfma_f32_16x16x4f32 (xf32 variant unavailable on gfx950) * bf16 matmul: mfma_f32_16x16x16bf16_1k (same as CDNA3) * int8 matmul: mfma_i32_16x16x32_i8/32x32x16 (same as CDNA3) - mmq.cuh: Include CDNA4 in stream-k kernel dispatch CDNA4 is largely compatible with CDNA3 except: - No xf32 MFMA (mfma_f32_16x16x8_xf32) — routes to f32 path - Different FP8 format (e4m3fn vs e4m3_fnuz) — not changed here Tested on AMD Instinct MI355X (gfx950), ROCm 7.0.1: - Build: compiles cleanly with -DAMDGPU_TARGETS=gfx950 - llama-bench (Qwen2.5-1.5B Q4_K_M, single GPU): * f16+FA: 40,013 tok/s prefill, 254 tok/s decode * q8_0+FA: functional - Flash attention: works correctly - MMQ: works correctly with stream-k dispatch Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
1 parent d6f3030 commit d132f22

4 files changed

Lines changed: 19 additions & 12 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
6868
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming
6969
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
70+
#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X
7071

7172
// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
7273
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
@@ -87,7 +88,8 @@
8788
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
8889
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
8990
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
90-
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
91+
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4)
92+
#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1)
9193

9294
// Moore Threads
9395
#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons

ggml/src/ggml-cuda/mma.cuh

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,8 @@ namespace ggml_cuda_mma {
10251025
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
10261026
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
10271027
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
1028-
#elif defined(CDNA2) || defined(CDNA1)
1028+
#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1)
1029+
// CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1
10291030
#pragma unroll
10301031
for (int i = 0; i < 2; ++i) {
10311032
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
@@ -1187,7 +1188,7 @@ namespace ggml_cuda_mma {
11871188
#elif defined(AMD_MFMA_AVAILABLE)
11881189
using floatx4_t = __attribute__((ext_vector_type(4))) float;
11891190
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
1190-
#if defined(CDNA3) || defined(CDNA2)
1191+
#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2)
11911192
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
11921193
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
11931194
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
@@ -1216,12 +1217,12 @@ namespace ggml_cuda_mma {
12161217
#if defined(AMD_MFMA_AVAILABLE)
12171218
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
12181219
int32x4_t * acc = (int32x4_t *) D.x;
1219-
#if defined(CDNA3)
1220+
#if defined(CDNA4) || defined(CDNA3)
12201221
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
12211222
((int64_t *) B.x)[0],
12221223
acc[0],
12231224
0, 0, 0);
1224-
#elif defined(CDNA2) || defined(CDNA)
1225+
#elif defined(CDNA2) || defined(CDNA1)
12251226
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
12261227
B.x[0],
12271228
acc[0],
@@ -1230,7 +1231,7 @@ namespace ggml_cuda_mma {
12301231
B.x[1],
12311232
acc[0],
12321233
0, 0, 0);
1233-
#endif // defined(CDNA3)
1234+
#endif // defined(CDNA4) || defined(CDNA3)
12341235

12351236
#elif defined(AMD_WMMA_AVAILABLE)
12361237

@@ -1295,12 +1296,12 @@ namespace ggml_cuda_mma {
12951296
#if defined(AMD_MFMA_AVAILABLE)
12961297
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
12971298
int32x16_t * acc = (int32x16_t *) D.x;
1298-
#if defined(CDNA3)
1299+
#if defined(CDNA4) || defined(CDNA3)
12991300
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
13001301
((int64_t *) B.x)[0],
13011302
acc[0],
13021303
0, 0, 0);
1303-
#elif defined(CDNA2) || defined(CDNA)
1304+
#elif defined(CDNA2) || defined(CDNA1)
13041305
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
13051306
B.x[0],
13061307
acc[0],
@@ -1309,7 +1310,7 @@ namespace ggml_cuda_mma {
13091310
B.x[1],
13101311
acc[0],
13111312
0, 0, 0);
1312-
#endif // defined(CDNA3)
1313+
#endif // defined(CDNA4) || defined(CDNA3)
13131314

13141315
#else
13151316
GGML_UNUSED_VARS(D, A, B);

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3645,7 +3645,7 @@ static __global__ void mul_mat_q(
36453645
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
36463646
return;
36473647
}
3648-
#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
3648+
#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
36493649

36503650
constexpr int ITER_K = get_iter_k(type);
36513651

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@
189189
#define GCN
190190
#endif // defined(GCN5) || defined(GCN4)
191191

192+
#if defined(__gfx950__)
193+
#define CDNA4
194+
#endif // defined(__gfx950__)
195+
192196
#if defined(__gfx942__)
193197
#define CDNA3
194198
#endif // defined(__gfx942__)
@@ -201,9 +205,9 @@
201205
#define CDNA1
202206
#endif // defined(__gfx908__)
203207

204-
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
208+
#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
205209
#define CDNA // For the entire family
206-
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
210+
#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
207211

208212
#if defined(__GFX12__)
209213
#define RDNA4

0 commit comments

Comments
 (0)