 #include "xqa_loader.h"
 #include <cassert>
 
-// Define global constants BEFORE including ANY header that uses them
-#define HEAD_ELEMS 128
-#define USE_PAGED_KV_CACHE 0
-#define TOKENS_PER_PAGE 0
-#define INPUT_FP16 1
-#define ALLOW_MULTI_BLOCK_MODE 1
-
-#pragma nv_diag_suppress 177
-#pragma nv_diag_suppress 20012
-
-// Include common headers once
-#include "cuda_hint.cuh"
-#include "mha.h"
-// Include all helpers globally to ensure visibility
-#include "ldgsts.cuh"
-#include "mhaUtils.cuh"
-#include "mha_components.cuh"
-#include "mma.cuh"
-#include "utils.cuh"
-#include "hostUtils.h"
-
-// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen
-#undef HEAD_GRP_SIZE
-#undef M_TILESIZE
-
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-// ============================================================================
-// FP16 KV Cache Instantiations
-// ============================================================================
-
-#define NAMESPACE_NAME grp1_fp16
-#define GRP_SIZE 1
-#define M_TILESIZE 8
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
-
-#define NAMESPACE_NAME grp2_fp16
-#define GRP_SIZE 2
-#define M_TILESIZE 8
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
-
-#define NAMESPACE_NAME grp4_fp16
-#define GRP_SIZE 4
-#define M_TILESIZE 8
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
-
-#define NAMESPACE_NAME grp8_fp16
-#define GRP_SIZE 8
-#define M_TILESIZE 8
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
-
-#define NAMESPACE_NAME grp16_fp16
-#define GRP_SIZE 16
-#define M_TILESIZE 16
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
+// Forward declarations of instantiated kernels from H128 and H64 namespaces
+namespace H128 {
+template <typename T>
+Status LaunchXQAKernelImpl(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    const void* query,
+    const void* key_cache,
+    const void* value_cache,
+    void* output,
+    const int batch_size,
+    const int num_heads,
+    const int kv_num_heads,
+    const int head_size,
+    const int actual_seq_len,
+    const int max_seq_len,
+    const float scale,
+    const bool is_bsnh,
+    const int* seq_lens,
+    const float* kv_cache_scale,
+    const int kv_quant_type,
+    void* workspace,
+    size_t workspace_size);
 
-#define NAMESPACE_NAME grp32_fp16
-#define GRP_SIZE 32
-#define M_TILESIZE 32
-#include "xqa_impl_gen.cuh"
-#undef NAMESPACE_NAME
-#undef GRP_SIZE
-#undef M_TILESIZE
+size_t GetXQAScratchSize(
+    const cudaDeviceProp& device_prop,
+    int batch_size,
+    int num_heads,
+    int kv_num_heads,
+    int max_seq_len);
+}  // namespace H128
 
-// Extern declarations for INT8 kernels (implemented in xqa_loader_int8.cu)
-Status LaunchXQAInt8Kernel(
+namespace H64 {
+template <typename T>
+Status LaunchXQAKernelImpl(
     const cudaDeviceProp& device_prop,
     cudaStream_t stream,
     const void* query,
@@ -103,13 +59,19 @@ Status LaunchXQAInt8Kernel(
     const bool is_bsnh,
     const int* seq_lens,
     const float* kv_cache_scale,
+    const int kv_quant_type,
     void* workspace,
     size_t workspace_size);
 
-// ============================================================================
-// Dispatcher
-// ============================================================================
+size_t GetXQAScratchSize(
+    const cudaDeviceProp& device_prop,
+    int batch_size,
+    int num_heads,
+    int kv_num_heads,
+    int max_seq_len);
+}  // namespace H64
 
+// Dispatcher Implementation
 template <typename T>
 Status LaunchXQAKernel(
     const cudaDeviceProp& device_prop,
@@ -131,36 +93,16 @@ Status LaunchXQAKernel(
     const int kv_quant_type,
     void* workspace,
     size_t workspace_size) {
-  if (head_size != 128) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=128.");
-  }
-
-  // Dispatch to INT8 path if requested
-  if (kv_quant_type == 1) {
-    if constexpr (std::is_same<T, half>::value) {
-      return LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    } else {
-      // BF16 case is handled in xqa_loader_bf16.cu via specialization
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 path mismatch.");
-    }
-  }
-
-  int group_size = num_heads / kv_num_heads;
-  switch (group_size) {
-    case 1:
-      return grp1_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    case 2:
-      return grp2_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    case 4:
-      return grp4_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    case 8:
-      return grp8_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    case 16:
-      return grp16_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    case 32:
-      return grp32_fp16::Launch<T>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, workspace, workspace_size);
-    default:
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size);
+  if (head_size == 128) {
+    return H128::LaunchXQAKernelImpl<T>(
+        device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size,
+        actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size);
+  } else if (head_size == 64) {
+    return H64::LaunchXQAKernelImpl<T>(
+        device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size,
+        actual_seq_len, max_seq_len, scale, is_bsnh, seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size);
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64 or 128. Input has ", head_size);
   }
 }
 
@@ -170,23 +112,29 @@ size_t GetXQAScratchSize(
     int num_heads,
     int kv_num_heads,
     int max_seq_len) {
-  int group_size = num_heads / kv_num_heads;
-  switch (group_size) {
-    case 1:
-      return grp1_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    case 2:
-      return grp2_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    case 4:
-      return grp4_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    case 8:
-      return grp8_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    case 16:
-      return grp16_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    case 32:
-      return grp32_fp16::GetScratchSize(device_prop, batch_size, kv_num_heads, max_seq_len);
-    default:
-      return 0;  // Not supported
-  }
+  // GetXQAScratchSize does not take head_size, so it cannot dispatch between H64 and H128 here.
+  // The scratch requirement is assumed to be driven by batch size, sequence length, and head counts
+  // rather than by the head dimension, so the H128 estimate is used as a conservative value.
+  return H128::GetXQAScratchSize(device_prop, batch_size, num_heads, kv_num_heads, max_seq_len);
 }
 
 // Instantiate template for half
@@ -211,6 +159,28 @@ template Status LaunchXQAKernel<half>(
     void* workspace,
     size_t workspace_size);
 
+// Instantiate template for BFloat16
+template Status LaunchXQAKernel<BFloat16>(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    const void* query,
+    const void* key_cache,
+    const void* value_cache,
+    void* output,
+    const int batch_size,
+    const int num_heads,
+    const int kv_num_heads,
+    const int head_size,
+    const int actual_seq_len,
+    const int max_seq_len,
+    const float scale,
+    const bool is_bsnh,
+    const int* seq_lens,
+    const float* kv_cache_scale,
+    const int kv_quant_type,
+    void* workspace,
+    size_t workspace_size);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
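
A minimal host-side sketch of how the dispatcher in this file might be driven, assuming a CUDA context is set up and that d_query, d_key_cache, d_value_cache, d_output, and d_seq_lens are device buffers allocated and filled elsewhere. The dimensions, layout flag, and the plain FP16 path (kv_quant_type = 0, kv_cache_scale = nullptr) are illustrative assumptions, not taken from this change.

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "xqa_loader.h"

// Hypothetical GQA decode step: 8 query heads sharing 2 KV heads, head_size 128.
const int batch_size = 2, num_heads = 8, kv_num_heads = 2, head_size = 128;
const int actual_seq_len = 1, max_seq_len = 4096;
const float scale = 0.088388f;  // 1/sqrt(head_size)

cudaStream_t stream = nullptr;  // default stream for this sketch
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, /*device=*/0);

// Query the scratch estimate and allocate workspace before launching.
size_t scratch_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize(
    prop, batch_size, num_heads, kv_num_heads, max_seq_len);

void* d_workspace = nullptr;
if (scratch_bytes > 0) {
  cudaMalloc(&d_workspace, scratch_bytes);
}

// Dispatches internally to H128::LaunchXQAKernelImpl<half> because head_size == 128;
// kv_quant_type == 0 selects the non-quantized KV-cache path in this sketch.
onnxruntime::Status status = onnxruntime::contrib::cuda::LaunchXQAKernel<half>(
    prop, stream, d_query, d_key_cache, d_value_cache, d_output,
    batch_size, num_heads, kv_num_heads, head_size,
    actual_seq_len, max_seq_len, scale, /*is_bsnh=*/true,
    d_seq_lens, /*kv_cache_scale=*/nullptr, /*kv_quant_type=*/0,
    d_workspace, scratch_bytes);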