Skip to content

Commit 0acfe30

Browse files
committed
add hdim 64 for xqa
1 parent 93d7c74 commit 0acfe30

16 files changed

+868
-150
lines changed

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,10 +596,10 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
596596
k_quant_type_ == KVQuantizationType::PER_TENSOR &&
597597
v_quant_type_ == KVQuantizationType::PER_TENSOR &&
598598
data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be equal; here we additionally require them to be the same tensor.
599-
parameters.head_size == 128 &&
599+
(parameters.head_size == 128 || parameters.head_size == 64) &&
600600
(group_size == 8 || group_size == 16 || group_size == 32));
601601
bool is_non_quantized_supported = !use_quantized_kv &&
602-
parameters.head_size == 128 &&
602+
(parameters.head_size == 128 || parameters.head_size == 64) &&
603603
(64 % group_size == 0);
604604

605605
data.use_xqa = !parameters.is_first_prompt &&

onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ inline size_t GetScratchSize(
2929
uint32_t nbSeq = static_cast<uint32_t>(batch_size * kv_num_heads);
3030

3131
size_t semaphore_size = nbSeq * sizeof(uint32_t);
32-
size_t scratch_size = ::onnxruntime::contrib::cuda::NAMESPACE_NAME::GetScratchSize(nbSeq, nbSubSeqPerSeq);
32+
size_t scratch_size = NAMESPACE_NAME::GetScratchSize(nbSeq, nbSubSeqPerSeq);
3333

3434
// Return total size with alignment padding
3535
return roundUp<size_t>(semaphore_size, 128) + scratch_size;

onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.cu

Lines changed: 97 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -4,89 +4,45 @@
44
#include "xqa_loader.h"
55
#include <cassert>
66

7-
// Define global constants BEFORE including ANY header that uses them
8-
#define HEAD_ELEMS 128
9-
#define USE_PAGED_KV_CACHE 0
10-
#define TOKENS_PER_PAGE 0
11-
#define INPUT_FP16 1
12-
#define ALLOW_MULTI_BLOCK_MODE 1
13-
14-
#pragma nv_diag_suppress 177
15-
#pragma nv_diag_suppress 20012
16-
17-
// Include common headers once
18-
#include "cuda_hint.cuh"
19-
#include "mha.h"
20-
// Include all helpers globally to ensure visibility
21-
#include "ldgsts.cuh"
22-
#include "mhaUtils.cuh"
23-
#include "mha_components.cuh"
24-
#include "mma.cuh"
25-
#include "utils.cuh"
26-
#include "hostUtils.h"
27-
28-
// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen
29-
#undef HEAD_GRP_SIZE
30-
#undef M_TILESIZE
31-
327
namespace onnxruntime {
338
namespace contrib {
349
namespace cuda {
3510

36-
// ============================================================================
37-
// FP16 KV Cache Instantiations
38-
// ============================================================================
39-
40-
#define NAMESPACE_NAME grp1_fp16
41-
#define GRP_SIZE 1
42-
#define M_TILESIZE 8
43-
#include "xqa_impl_gen.cuh"
44-
#undef NAMESPACE_NAME
45-
#undef GRP_SIZE
46-
#undef M_TILESIZE
47-
48-
#define NAMESPACE_NAME grp2_fp16
49-
#define GRP_SIZE 2
50-
#define M_TILESIZE 8
51-
#include "xqa_impl_gen.cuh"
52-
#undef NAMESPACE_NAME
53-
#undef GRP_SIZE
54-
#undef M_TILESIZE
55-
56-
#define NAMESPACE_NAME grp4_fp16
57-
#define GRP_SIZE 4
58-
#define M_TILESIZE 8
59-
#include "xqa_impl_gen.cuh"
60-
#undef NAMESPACE_NAME
61-
#undef GRP_SIZE
62-
#undef M_TILESIZE
63-
64-
#define NAMESPACE_NAME grp8_fp16
65-
#define GRP_SIZE 8
66-
#define M_TILESIZE 8
67-
#include "xqa_impl_gen.cuh"
68-
#undef NAMESPACE_NAME
69-
#undef GRP_SIZE
70-
#undef M_TILESIZE
71-
72-
#define NAMESPACE_NAME grp16_fp16
73-
#define GRP_SIZE 16
74-
#define M_TILESIZE 16
75-
#include "xqa_impl_gen.cuh"
76-
#undef NAMESPACE_NAME
77-
#undef GRP_SIZE
78-
#undef M_TILESIZE
11+
// Forward declarations of instantiated kernels from H128 and H64 namespaces
12+
namespace H128 {
13+
template <typename T>
14+
Status LaunchXQAKernelImpl(
15+
const cudaDeviceProp& device_prop,
16+
cudaStream_t stream,
17+
const void* query,
18+
const void* key_cache,
19+
const void* value_cache,
20+
void* output,
21+
const int batch_size,
22+
const int num_heads,
23+
const int kv_num_heads,
24+
const int head_size,
25+
const int actual_seq_len,
26+
const int max_seq_len,
27+
const float scale,
28+
const bool is_bsnh,
29+
const int* seq_lens,
30+
const float* kv_cache_scale,
31+
const int kv_quant_type,
32+
void* workspace,
33+
size_t workspace_size);
7934

80-
#define NAMESPACE_NAME grp32_fp16
81-
#define GRP_SIZE 32
82-
#define M_TILESIZE 32
83-
#include "xqa_impl_gen.cuh"
84-
#undef NAMESPACE_NAME
85-
#undef GRP_SIZE
86-
#undef M_TILESIZE
35+
size_t GetXQAScratchSize(
36+
const cudaDeviceProp& device_prop,
37+
int batch_size,
38+
int num_heads,
39+
int kv_num_heads,
40+
int max_seq_len);
41+
} // namespace H128
8742

88-
// Extern declarations for INT8 kernels (implemented in xqa_loader_int8.cu)
89-
Status LaunchXQAInt8Kernel(
43+
namespace H64 {
44+
template <typename T>
45+
Status LaunchXQAKernelImpl(
9046
const cudaDeviceProp& device_prop,
9147
cudaStream_t stream,
9248
const void* query,
@@ -103,13 +59,19 @@ Status LaunchXQAInt8Kernel(
10359
const bool is_bsnh,
10460
const int* seq_lens,
10561
const float* kv_cache_scale,
62+
const int kv_quant_type,
10663
void* workspace,
10764
size_t workspace_size);
10865

109-
// ============================================================================
110-
// Dispatcher
111-
// ============================================================================
66+
size_t GetXQAScratchSize(
67+
const cudaDeviceProp& device_prop,
68+
int batch_size,
69+
int num_heads,
70+
int kv_num_heads,
71+
int max_seq_len);
72+
} // namespace H64
11273

74+
// Dispatcher Implementation
11375
template <typename T>
11476
Status LaunchXQAKernel(
11577
const cudaDeviceProp& device_prop,
@@ -131,36 +93,16 @@ Status LaunchXQAKernel(
13193
const int kv_quant_type,
13294
void* workspace,
13395
size_t workspace_size) {
134-
if (head_size != 128) {
135-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=128.");
136-
}
137-
138-
// Dispatch to INT8 path if requested
139-
if (kv_quant_type == 1) {
140-
if constexpr (std::is_same<T, half>::value) {
141-
return LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
142-
} else {
143-
// BF16 case is handled in xqa_loader_bf16.cu via specialization
144-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 path mismatch.");
145-
}
146-
}
147-
148-
int group_size = num_heads / kv_num_heads;
149-
switch (group_size) {
150-
case 1:
151-
return grp1_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
152-
case 2:
153-
return grp2_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
154-
case 4:
155-
return grp4_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
156-
case 8:
157-
return grp8_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
158-
case 16:
159-
return grp16_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
160-
case 32:
161-
return grp32_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
162-
default:
163-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size);
96+
if (head_size == 128) {
97+
return H128::LaunchXQAKernelImpl<T>(
98+
device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size,
99+
actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size);
100+
} else if (head_size == 64) {
101+
return H64::LaunchXQAKernelImpl<T>(
102+
device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size,
103+
actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size);
104+
} else {
105+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64 or 128. Input has ", head_size);
164106
}
165107
}
166108

@@ -170,23 +112,29 @@ size_t GetXQAScratchSize(
170112
int num_heads,
171113
int kv_num_heads,
172114
int max_seq_len) {
173-
int group_size = num_heads / kv_num_heads;
174-
switch (group_size) {
175-
case 1:
176-
return grp1_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
177-
case 2:
178-
return grp2_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
179-
case 4:
180-
return grp4_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
181-
case 8:
182-
return grp8_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
183-
case 16:
184-
return grp16_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
185-
case 32:
186-
return grp32_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
187-
default:
188-
return 0; // Not supported
189-
}
115+
// GetXQAScratchSize does not take head_size, so we cannot dispatch between
116+
// H64 and H128 here. We return the H128 estimate on the assumption that the
117+
// scratch size computed in xqa_impl_gen.cuh depends on batch/head counts and
118+
// sequence length rather than HEAD_ELEMS — and that, if it does depend on
119+
// HEAD_ELEMS, 128 is the conservative (larger) bound.
120+
// TODO(review): confirm this assumption against mha_impl.cuh, or add a
121+
// head_size parameter to this API and dispatch to H64::GetXQAScratchSize.
122+
return H128::GetXQAScratchSize(device_prop, batch_size, num_heads, kv_num_heads, max_seq_len);
190138
}
191139

192140
// Instantiate template for half
@@ -211,6 +159,28 @@ template Status LaunchXQAKernel<half>(
211159
void* workspace,
212160
size_t workspace_size);
213161

162+
// Instantiate template for BFloat16
163+
template Status LaunchXQAKernel<BFloat16>(
164+
const cudaDeviceProp& device_prop,
165+
cudaStream_t stream,
166+
const void* query,
167+
const void* key_cache,
168+
const void* value_cache,
169+
void* output,
170+
const int batch_size,
171+
const int num_heads,
172+
const int kv_num_heads,
173+
const int head_size,
174+
const int actual_seq_len,
175+
const int max_seq_len,
176+
const float scale,
177+
const bool is_bsnh,
178+
const int* seq_lens,
179+
const float* kv_cache_scale,
180+
const int kv_quant_type,
181+
void* workspace,
182+
size_t workspace_size);
183+
214184
} // namespace cuda
215185
} // namespace contrib
216186
} // namespace onnxruntime
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define HEAD_ELEMS 128
5+
#define HEAD_DIM_NAMESPACE H128
6+
7+
#include "xqa_loader_impl.cuh"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace cuda {
12+
13+
// Explicit instantiation
14+
template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<half>(
15+
const cudaDeviceProp& device_prop,
16+
cudaStream_t stream,
17+
const void* query,
18+
const void* key_cache,
19+
const void* value_cache,
20+
void* output,
21+
const int batch_size,
22+
const int num_heads,
23+
const int kv_num_heads,
24+
const int head_size,
25+
const int actual_seq_len,
26+
const int max_seq_len,
27+
const float scale,
28+
const bool is_bsnh,
29+
const int* seq_lens,
30+
const float* kv_cache_scale,
31+
const int kv_quant_type,
32+
void* workspace,
33+
size_t workspace_size);
34+
35+
} // namespace cuda
36+
} // namespace contrib
37+
} // namespace onnxruntime
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define HEAD_ELEMS 64
5+
#define HEAD_DIM_NAMESPACE H64
6+
7+
#include "xqa_loader_impl.cuh"
8+
9+
namespace onnxruntime {
10+
namespace contrib {
11+
namespace cuda {
12+
13+
// Explicit instantiation
14+
template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<half>(
15+
const cudaDeviceProp& device_prop,
16+
cudaStream_t stream,
17+
const void* query,
18+
const void* key_cache,
19+
const void* value_cache,
20+
void* output,
21+
const int batch_size,
22+
const int num_heads,
23+
const int kv_num_heads,
24+
const int head_size,
25+
const int actual_seq_len,
26+
const int max_seq_len,
27+
const float scale,
28+
const bool is_bsnh,
29+
const int* seq_lens,
30+
const float* kv_cache_scale,
31+
const int kv_quant_type,
32+
void* workspace,
33+
size_t workspace_size);
34+
35+
} // namespace cuda
36+
} // namespace contrib
37+
} // namespace onnxruntime
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define HEAD_ELEMS 128
5+
#define HEAD_DIM_NAMESPACE H128
6+
7+
#include "xqa_loader_bf16_impl.cuh"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#define HEAD_ELEMS 64
5+
#define HEAD_DIM_NAMESPACE H64
6+
7+
#include "xqa_loader_bf16_impl.cuh"

0 commit comments

Comments
 (0)