Skip to content

Commit 30968e9

Browse files
authored
[CUDA] Fix build errors (#27319)
Fix build errors that were not caught by CI but show up in post-merge builds. They were caused by two commits: 5a9877a and 9adf238. One commit changed the GQA interface, while the other commit uses the old GQA interface.
1 parent a9ebcbd commit 30968e9

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -612,8 +612,8 @@ Status ExtremeDecoding(
612612
batch_size,
613613
parameters.seqlen_present_kv_cache, // max_seqlen (capacity)
614614
data.past_seq_lens,
615-
data.cos_cache,
616-
data.sin_cache,
615+
reinterpret_cast<const CudaT*>(data.cos_cache),
616+
reinterpret_cast<const CudaT*>(data.sin_cache),
617617
parameters.do_rotary ? parameters.rotary_dim : 0,
618618
data.position_ids,
619619
parameters.rotary_interleaved,
@@ -1105,6 +1105,7 @@ Status QkvToContext(
11051105

11061106
template struct GroupQueryAttentionData<half, half>;
11071107
template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>;
1108+
template struct GroupQueryAttentionData<BFloat16, BFloat16>;
11081109
template struct GroupQueryAttentionData<half, int8_t>;
11091110

11101111
template Status QkvToContext<half, half>(
@@ -1121,6 +1122,13 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>(
11211122
contrib::GroupQueryAttentionParameters& parameters,
11221123
GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data);
11231124

1125+
template Status QkvToContext<BFloat16, BFloat16>(
1126+
const cudaDeviceProp& device_prop,
1127+
cublasHandle_t& cublas,
1128+
Stream* ort_stream,
1129+
contrib::GroupQueryAttentionParameters& parameters,
1130+
GroupQueryAttentionData<BFloat16, BFloat16>& data);
1131+
11241132
template Status QkvToContext<half, int8_t>(
11251133
const cudaDeviceProp& device_prop,
11261134
cublasHandle_t& cublas,

onnxruntime/core/providers/cuda/llm/attention.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
196196
gqa_parameters.num_splits = 1;
197197

198198
// Construct GroupQueryAttentionData
199-
onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data;
199+
onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT, CudaT> gqa_data;
200200

201201
// Scratch buffers for flash/memory efficient attention
202202
IAllocatorUniquePtr<void> k_buffer;
@@ -355,6 +355,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
355355
// Centralized scratch buffer allocation using GQABufferRequirements
356356
auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute<T>(
357357
gqa_parameters,
358+
false, // use_xqa
358359
gqa_data.use_flash_attention,
359360
gqa_data.use_flash_attention_fast_decode,
360361
gqa_data.use_memory_efficient_attention);
@@ -478,7 +479,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
478479
// Call GQA kernel (with flash or memory efficient attention)
479480
cublasHandle_t cublas = GetCublasHandle(context);
480481

481-
return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
482+
return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(
482483
device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data);
483484
}
484485

0 commit comments

Comments (0)