
Commit 990ba5f

Fix GQA Parity (#27108)
Fixes #27079 - Qwen3 model quality regression on CUDA backend.

### Root Cause Analysis

The parity issue was caused by **buffer pointer misconfiguration** in the GQA (Group Query Attention) QKV preprocessing pipeline. The original implementation used multiple separate kernels for:

1. Unpacking the packed QKV tensor
2. Applying RoPE (Rotary Position Embedding) to Q and K
3. Appending K/V to the cache

This multi-kernel approach created opportunities for misconfiguration:

- Buffers were allocated but not properly used
- Pointers could reference memory that was not yet allocated or initialized
- Buffer-sharing logic was fragmented across different code paths

### Solution

Consolidate QKV preprocessing into a **single fused kernel** (`UnpackRoPEAppend`) that performs all operations in one pass:

1. **Unified kernel design**: A single kernel handles unpacking, RoPE application, and cache-append operations (an illustrative sketch follows this summary)
2. **Simplified buffer management**: The new `PrepareQKV` function clearly manages buffer allocation and ensures proper initialization
3. **Explicit past-to-present cache copy**: When `past_present_share_buffer` is false, explicitly copy the past KV cache to the present buffer before appending new tokens
4. **Zero-initialization for non-shared buffers**: Clear present KV buffers when not sharing with past to ensure deterministic output

### Changes Summary

| File | Changes |
|------|---------|
| `group_query_attention_qkv.cuh` | New fused `UnpackRoPEAppend` kernel with shared-memory optimization for non-interleaved RoPE |
| `group_query_attention_impl.cu` | New `PrepareQKV` helper function that orchestrates buffer setup and kernel launch |
| `group_query_attention.cc` | Simplified operator logic by delegating QKV prep to the unified helper |
| `test_gqa.py` | Enhanced test coverage for various QKV configurations |

### Key Improvements

- **Reduced kernel launches**: From 4-5 separate kernel calls to a single fused kernel
- **Better memory safety**: All buffer pointers are validated in a single location
- **Improved RoPE handling**: Uses shared memory for efficient non-interleaved RoPE computation
- **Deterministic output**: Explicit buffer initialization ensures consistent results across runs
- **Compatible with quantized KV cache**: The new preprocessing kernel design supports future quantization work

### Testing

- All existing GQA unit tests pass
- Verified Qwen3 model no longer produces gibberish output
- Tested both fp16/bf16 and various head configurations
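To make the fused-kernel idea concrete, here is a minimal CUDA sketch of an unpack + RoPE + cache-append pass. It is not the actual `UnpackRoPEAppend` kernel from `group_query_attention_qkv.cuh`: the kernel name, argument layout, BNSH cache shapes, and cos/sin cache layout are assumptions, and it shows the simpler interleaved RoPE variant rather than the shared-memory non-interleaved path described above.

```cuda
// Illustrative sketch only (assumed shapes and names, not the ORT kernel):
//   packed_qkv : [batch, seq_len, (num_heads + 2 * kv_num_heads) * head_size]
//   q_out      : [batch, num_heads, seq_len, head_size]          (BNSH)
//   present_k/v: [batch, kv_num_heads, max_seq_len, head_size]   (BNSH cache)
//   cos/sin    : [max_position, head_size / 2]
// One thread handles one (token, head, rotary pair); head_size assumed even.
#include <cuda_fp16.h>

__global__ void FusedUnpackRopeAppendSketch(
    const half* packed_qkv, half* q_out, half* present_k, half* present_v,
    const float* cos_cache, const float* sin_cache, const int* past_seq_lens,
    int batch, int seq_len, int max_seq_len,
    int num_heads, int kv_num_heads, int head_size) {
  const int total_heads = num_heads + 2 * kv_num_heads;
  const int pairs_per_token = total_heads * head_size / 2;
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= batch * seq_len * pairs_per_token) return;

  const int pair = idx % (head_size / 2);
  const int head = (idx / (head_size / 2)) % total_heads;
  const int token = idx / pairs_per_token;
  const int b = token / seq_len;
  const int s = token % seq_len;

  // Read the interleaved pair (x0, x1) straight from the packed QKV tensor.
  const half* src = packed_qkv + (size_t)token * total_heads * head_size +
                    (size_t)head * head_size + 2 * pair;
  float x0 = __half2float(src[0]);
  float x1 = __half2float(src[1]);

  const bool is_q = head < num_heads;
  const bool is_k = !is_q && head < num_heads + kv_num_heads;
  const int pos = past_seq_lens[b] + s;  // absolute position of this token

  // Apply RoPE to Q and K only; V passes through unchanged.
  if (is_q || is_k) {
    const float c = cos_cache[pos * (head_size / 2) + pair];
    const float sn = sin_cache[pos * (head_size / 2) + pair];
    const float r0 = x0 * c - x1 * sn;
    const float r1 = x0 * sn + x1 * c;
    x0 = r0;
    x1 = r1;
  }

  if (is_q) {
    // Q goes to the unpacked BNSH buffer consumed by the attention kernel.
    half* dst = q_out + (((size_t)b * num_heads + head) * seq_len + s) * head_size + 2 * pair;
    dst[0] = __float2half(x0);
    dst[1] = __float2half(x1);
  } else {
    // K/V are appended to the present cache at offset past_len + s.
    const int kv_head = is_k ? head - num_heads : head - num_heads - kv_num_heads;
    half* cache = is_k ? present_k : present_v;
    half* dst = cache + (((size_t)b * kv_num_heads + kv_head) * max_seq_len + pos) * head_size + 2 * pair;
    dst[0] = __float2half(x0);
    dst[1] = __float2half(x1);
  }
}
```

The point of the single pass is that every destination pointer (unpacked Q, present K, present V) is bound in one launch, which removes the cross-kernel buffer wiring that allowed the misconfiguration described above.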
1 parent 879ec03 commit 990ba5f

File tree

12 files changed: +835 -1015 lines changed

onnxruntime/contrib_ops/cpu/bert/attention_parameters.h

Lines changed: 8 additions & 9 deletions
@@ -25,15 +25,15 @@ struct AttentionParameters {
   int num_splits; // number of splits for splitkv
   int rotary_dim = 0; // rotary embedding dimension
   int beam_width;
-  bool is_unidirectional;
-  bool past_present_share_buffer;
+  bool is_unidirectional = false;
+  bool past_present_share_buffer = false;
   bool is_packed_qkv = false; // whether qkv is packed
-  bool do_rotary;
-  bool broadcast_attn_bias_dim_0;
-  bool broadcast_attn_bias_dim_1;
+  bool do_rotary = false;
+  bool broadcast_attn_bias_dim_0 = false;
+  bool broadcast_attn_bias_dim_1 = false;
   float mask_filter_value;
   float scale;
-  bool use_tf32;
+  bool use_tf32 = false;
   bool is_output_bnsh = false; // whether the output format is BNSH
   AttentionMaskType mask_type;
   AttentionQkvFormat qkv_format;
@@ -88,9 +88,8 @@ struct GroupQueryAttentionParameters : AttentionParameters {
   int seqlen_past_kv_cache; // sequence length of past kv tensor
   int seqlen_present_kv_cache; // sequence length of present kv tensor
   int local_window_size; // Mask out tokens prior to total_sequence_length - local_window_size
-  bool kv_share_buffer;
-  bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
-  bool is_first_prompt; // indicates whether this is first decoding step
+  bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
+  bool is_first_prompt; // indicates whether this is first decoding step
   bool rotary_interleaved;
   bool use_smooth_softmax;
   float softcap;

onnxruntime/contrib_ops/cpu/utils/debug_macros.h

Lines changed: 12 additions & 6 deletions
@@ -1,12 +1,9 @@
 #pragma once
+#include <cstdio>
 #include "core/common/make_string.h"
 
-// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
-
-#ifdef DEBUG_GENERATION
-#define DUMP_TENSOR_LEVEL 2
-#else
-#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation.
+#if !defined(DUMP_TENSOR_LEVEL)
+#define DUMP_TENSOR_LEVEL 0
 #endif
 
 #define DUMP_CPU_TENSOR_LEVEL DUMP_TENSOR_LEVEL
@@ -48,3 +45,12 @@
 #else
 #define DUMP_TENSOR_D(...)
 #endif
+
+#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG)
+#define DEBUG_PRINTF(fmt, ...) \
+  std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
+#else
+#define DEBUG_PRINTF(fmt, ...) \
+  do {                         \
+  } while (0)
+#endif
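The `DEBUG_PRINTF` macro added above is only active in GCC/Clang builds without `NDEBUG`; otherwise it compiles to an empty statement. A hypothetical call site (the function and variable names below are illustrative, not from this commit) might look like:

```cpp
// Hypothetical usage of DEBUG_PRINTF from debug_macros.h; in GCC/Clang debug
// builds this prints via std::printf, otherwise it expands to a no-op do {} while (0).
#include "contrib_ops/cpu/utils/debug_macros.h"

static void LogGqaBufferSizes(size_t qkv_buffer_bytes, int num_splits) {
  DEBUG_PRINTF("qkv_buffer_bytes=%zu num_splits=%d", qkv_buffer_bytes, num_splits);
}
```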

onnxruntime/contrib_ops/cuda/bert/attention_data.h

Lines changed: 4 additions & 19 deletions
@@ -179,35 +179,20 @@ struct GroupQueryAttentionData {
 
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
-  T* unpacked_qkv_buffer = nullptr;
-  T* rotary_buffer = nullptr;
-  int64_t* position_ids_buffer = nullptr; // Separate buffer for generated position IDs
+  T* qkv_buffer = nullptr;
+
   T* k = nullptr;
   T* v = nullptr;
 
-#ifndef NDEBUG
-  // Buffer size tracking for debug validation
-  // Allocated sizes are set during buffer allocation in group_query_attention.cc
-  // Max used sizes are updated during kernel calls in group_query_attention_impl.cu
-  // Validated before operator returns to ensure usage exactly matches allocation
-  size_t unpacked_qkv_buffer_size = 0; // Allocated size
-  size_t rotary_buffer_size = 0; // Allocated size
-  size_t position_ids_buffer_size = 0; // Allocated size
-  mutable size_t unpacked_qkv_max_used = 0; // Max offset accessed (updated by kernels)
-  mutable size_t rotary_max_used = 0; // Max offset accessed (updated by kernels)
-  mutable size_t position_ids_max_used = 0; // Max offset accessed (updated by kernels)
-#endif
-
   // Output Tensors
   T* output = nullptr;
-  T* present_key = nullptr;
-  T* present_value = nullptr;
+  void* present_key = nullptr;
+  void* present_value = nullptr;
 
   // Kernel Flags
   bool use_flash_attention = false;
   bool use_memory_efficient_attention = false;
   bool use_flash_attention_fast_decode = false;
-  bool disable_fused_kv = false;
 };
 
 template <typename T>

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc

Lines changed: 0 additions & 1 deletion
@@ -533,7 +533,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
   params.is_seqlens_k_cumulative = seqlens_k_ == nullptr;
   if (seqlens_k_ != nullptr) {
     params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
-    params.seqused_k = static_cast<int*>(seqlens_k_);
   }
 
   if (rotary_cos != nullptr) {

onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h

Lines changed: 5 additions & 2 deletions
@@ -132,9 +132,12 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
 
 size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads);
 size_t get_softmax_lse_size(size_t token_count, size_t num_heads);
+size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q);
+size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads,
+                          size_t seqlen_q, size_t head_size_rounded);
 
-std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads,
-                                                                   size_t head_size, size_t num_SMs);
+std::tuple<size_t, size_t, size_t> get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k,
+                                                                   size_t num_heads, size_t head_size, size_t num_SMs);
 
 bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k);
 
onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 66 additions & 60 deletions
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include <vector>
+#include <algorithm>
 #include "core/providers/cuda/cuda_common.h"
 #include "core/platform/env_var_utils.h"
 #include "contrib_ops/cuda/bert/group_query_attention_impl.h"
@@ -39,8 +40,17 @@
 REGISTER_KERNEL_TYPED(BFloat16)
 
 constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE";
-constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV";
 
+// Group Query Attention (GQA) Operator
+//
+// This operator implements Group Query Attention, a variation of Multi-Head Attention (MHA)
+// where the number of key/value heads is smaller than the number of query heads.
+// It supports:
+// - Rotary Positional Embeddings (RoPE)
+// - KV Cache (past/present key/value)
+// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData
+// - Flash Attention and Memory Efficient Attention backends
+//
 template <typename T>
 GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
     : CudaKernel(info) {
@@ -63,7 +73,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
 
   disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
 
-  // Memory efficient attention supports float and float16. BFloat16 support is added for SM80+ via cutlass kernels.
+  // Memory efficient attention supports float and float16. BFloat16 support added for SM80+.
   disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
 
   if (!disable_flash_attention_) {
@@ -72,9 +82,23 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
   }
 
   disable_flash_decode_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFlashDecode, false);
-  disable_fused_kv_ = ParseEnvironmentVariableWithDefault<bool>(kDisableFusedKv, false);
 }
 
+// ComputeInternal executes the GQA kernel.
+//
+// Inputs:
+// 0. query (Tensor) [batch, sequence_length, hidden_size]
+// 1. key (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
+// 2. value (Tensor) [batch, sequence_length, kv_hidden_size] (Optional)
+// 3. past_key (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
+// 4. past_value (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional)
+// 5. seqlens_k (Tensor) [batch] - Total sequence length minus 1 (for historical compatibility)
+// 6. total_seqlen (Tensor) - Max total sequence length
+// 7. cos_cache (Tensor) - Precomputed cosine table for RoPE
+// 8. sin_cache (Tensor) - Precomputed sine table for RoPE
+// 9. position_ids (Tensor) - Position indices for RoPE
+// 10. attention_bias (Tensor) - Not supported in this kernel
+// 11. head_sink (Tensor) - Attention sink for GPT-OSS
 template <typename T>
 Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* query = context->Input<Tensor>(0);
@@ -162,7 +186,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   IAllocatorUniquePtr<void> k_buffer;
   IAllocatorUniquePtr<void> v_buffer;
   IAllocatorUniquePtr<void> rotary_buffer;
-  IAllocatorUniquePtr<void> position_ids_buffer;
   IAllocatorUniquePtr<void> fmha_buffer;
   IAllocatorUniquePtr<void> unpacked_qkv_buffer;
   IAllocatorUniquePtr<int> seq_lens_buffer;
@@ -185,24 +208,39 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
   data.present_value = reinterpret_cast<CudaT*>(context->Output<Tensor>(2)->MutableData<T>());
 
+  // Compute past_present_share_buffer early since it's needed for flash attention path selection.
+  // This compares the final pointer values after quantization handling.
+  parameters.past_present_share_buffer = (data.past_key == data.present_key);
+
 #if USE_FLASH_ATTENTION
   bool use_flash_attention = !disable_flash_attention_ &&
                              onnxruntime::flash::is_supported<T>(device_prop,
                                                                  parameters.head_size,
                                                                  parameters.num_heads,
                                                                  parameters.kv_num_heads);
-  data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer;
-  if (use_flash_attention) {
-    data.use_flash_attention = true;
-    data.use_memory_efficient_attention = false;
 
+  data.use_flash_attention = use_flash_attention;
+  data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer;
+
+  if (use_flash_attention) {
     // Allocate Flash specific buffers (Softmax LSE, Accum)
     size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
+
+    int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads;
     auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
-        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
+        parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split,
         parameters.head_size, device_prop.multiProcessorCount);
+
     parameters.num_splits = static_cast<int>(num_splits);
 
+    if (data.use_flash_attention_fast_decode && num_splits > 1) {
+      // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel.
+      // However, the LSE and Accum buffers must store results for ALL num_heads.
+      softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length);
+      auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; };
+      out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32));
+    }
+
    softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
    softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
    out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
@@ -214,11 +252,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
 #endif
 
   if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) {
-    // FlashAttentionDecoding Fast Path:
+    // FlashDecoding Fast Path:
     // - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed.
-    // - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to:
-    //   1. Determine where to append new K/V in the cache
-    //   2. Apply correct causal masking (attention only to positions [0, past_seq_len])
     // - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1.
     // - This optimization avoids launching GetSequenceLengths kernel for single-token decoding.
     data.past_seq_lens = const_cast<int*>(total_seq_lens_minus_one->Data<int>());
@@ -239,16 +274,20 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                           parameters.is_first_prompt,
                                           cuda_stream,
                                           device_prop.maxThreadsPerBlock));
+    DUMP_TENSOR_INIT();
+    DUMP_TENSOR("total_seq_lens", data.total_seq_lens, parameters.batch_size, 1);
+    DUMP_TENSOR("past_seq_lens", data.past_seq_lens, parameters.batch_size, 1);
+    DUMP_TENSOR("padded_seq_lens", data.padded_seq_lens, parameters.batch_size, 1);
   }
 
-  if (!use_flash_attention) {
-    // Fall back to memory efficient attention.
 #if USE_MEMORY_EFFICIENT_ATTENTION
+  if (!data.use_flash_attention) {
+    // Fall back to memory efficient attention.
     int sm = (device_prop.major * 10) + device_prop.minor;
     bool use_memory_efficient_attention =
-        !use_flash_attention &&
         !disable_memory_efficient_attention_ &&
        has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value, std::is_same<T, BFloat16>::value, parameters.head_size, parameters.head_size);
+    data.use_memory_efficient_attention = use_memory_efficient_attention;
 
     // KV buffer for head expansion (when num_heads != kv_num_heads)
     size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads))
@@ -262,49 +301,30 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
     k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
     v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
     fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
-#else
-    constexpr bool use_memory_efficient_attention = false;
-#endif
-
-    data.use_memory_efficient_attention = use_memory_efficient_attention;
-    data.use_flash_attention = false;
 
     data.k = reinterpret_cast<CudaT*>(k_buffer.get());
     data.v = reinterpret_cast<CudaT*>(v_buffer.get());
     data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
-    data.disable_fused_kv = disable_fused_kv_;
   }
+#endif
 
+  // -------------
   // Centralized scratch buffer allocation using GQABufferRequirements
   // This ensures allocation logic stays in sync with kernel usage
   auto buffer_req = GQABufferRequirements::Compute<T>(
       parameters,
-      use_flash_attention,
+      data.use_flash_attention,
      data.use_flash_attention_fast_decode,
      data.use_memory_efficient_attention);
 
-  if (buffer_req.unpacked_qkv_bytes > 0) {
-    unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.unpacked_qkv_bytes, context->GetComputeStream());
-    data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
-  }
-  if (buffer_req.rotary_buffer_bytes > 0) {
-    rotary_buffer = GetScratchBuffer<void>(buffer_req.rotary_buffer_bytes, context->GetComputeStream());
-    data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
+  if (buffer_req.qkv_buffer_bytes > 0) {
+    unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
+    data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
   }
-  if (buffer_req.position_ids_bytes > 0) {
-    position_ids_buffer = GetScratchBuffer<void>(buffer_req.position_ids_bytes, context->GetComputeStream());
-    data.position_ids_buffer = reinterpret_cast<int64_t*>(position_ids_buffer.get());
-  }
-#ifndef NDEBUG
-  // Track allocated sizes for validation
-  data.unpacked_qkv_buffer_size = buffer_req.unpacked_qkv_bytes;
-  data.rotary_buffer_size = buffer_req.rotary_buffer_bytes;
-  data.position_ids_buffer_size = buffer_req.position_ids_bytes;
-#endif
 
   if (kernel_options_->AllowDebugInfo()) {
     AttentionKernelDebugInfo debug_info;
-    debug_info.use_flash_attention = use_flash_attention;
+    debug_info.use_flash_attention = data.use_flash_attention;
     debug_info.use_efficient_attention = data.use_memory_efficient_attention;
 
     debug_info.Print("GroupQueryAttention",
@@ -313,12 +333,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
                      std::is_same<T, BFloat16>::value);
   }
 
-  if (data.past_key == data.present_key) {
-    parameters.kv_share_buffer = true;
-    ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when kv_share_buffer is true");
+  // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup)
+  if (parameters.past_present_share_buffer) {
+    ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true");
   } else {
-    parameters.kv_share_buffer = false;
-    ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false");
+    ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false");
   }
 
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -337,19 +356,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(QkvToContext<CudaT>(
       device_prop, cublas, context->GetComputeStream(), parameters, data));
 
-#ifndef NDEBUG
-  // Validate buffer usage matches allocation exactly
-  ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size,
-              "unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used,
-              " bytes but allocated ", data.unpacked_qkv_buffer_size);
-  ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size,
-              "rotary_buffer: used ", data.rotary_max_used,
-              " bytes but allocated ", data.rotary_buffer_size);
-  ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size,
-              "position_ids_buffer: used ", data.position_ids_max_used,
-              " bytes but allocated ", data.position_ids_buffer_size);
-#endif
-
   return Status::OK();
 }
 

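The explicit past-to-present copy and zero-initialization mentioned in the commit message live in `group_query_attention_impl.cu`, which is not shown in this excerpt. As a rough host-side sketch of that behavior (the function name, flat-copy layout assumption, and error handling below are assumptions, not the actual `PrepareQKV` code):

```cuda
// Illustrative sketch: when the KV cache is not shared, the present buffers are
// zero-initialized and the past cache is copied in before new tokens are appended.
// Assumes past and present caches have the same BNSH layout/capacity, so a flat
// device-to-device copy suffices; the real implementation handles strided layouts.
#include <cuda_runtime.h>

template <typename T>
cudaError_t PreparePresentKvSketch(const T* past_key, const T* past_value,
                                   T* present_key, T* present_value,
                                   size_t past_bytes, size_t present_bytes,
                                   bool past_present_share_buffer,
                                   cudaStream_t stream) {
  if (past_present_share_buffer) {
    // Shared buffer: past data is already in place; nothing to copy.
    return cudaSuccess;
  }
  // Clear the present buffers so untouched cache slots are deterministic.
  cudaError_t err = cudaMemsetAsync(present_key, 0, present_bytes, stream);
  if (err != cudaSuccess) return err;
  err = cudaMemsetAsync(present_value, 0, present_bytes, stream);
  if (err != cudaSuccess) return err;
  // Copy the past KV cache into the present buffers before appending new tokens.
  if (past_key != nullptr && past_bytes > 0) {
    err = cudaMemcpyAsync(present_key, past_key, past_bytes, cudaMemcpyDeviceToDevice, stream);
    if (err != cudaSuccess) return err;
    err = cudaMemcpyAsync(present_value, past_value, past_bytes, cudaMemcpyDeviceToDevice, stream);
  }
  return err;
}
```
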
onnxruntime/contrib_ops/cuda/bert/group_query_attention.h

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ class GroupQueryAttention final : public CudaKernel {
   bool disable_flash_attention_;
   bool disable_memory_efficient_attention_;
   bool disable_flash_decode_;
-  bool disable_fused_kv_;
 
   static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
   IAllocatorUniquePtr<int> zeros_;
