12 changes: 10 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -612,8 +612,8 @@ Status ExtremeDecoding(
       batch_size,
       parameters.seqlen_present_kv_cache,  // max_seqlen (capacity)
       data.past_seq_lens,
-      data.cos_cache,
-      data.sin_cache,
+      reinterpret_cast<const CudaT*>(data.cos_cache),
+      reinterpret_cast<const CudaT*>(data.sin_cache),
       parameters.do_rotary ? parameters.rotary_dim : 0,
       data.position_ids,
       parameters.rotary_interleaved,
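A hedged aside on the casts above: data.cos_cache and data.sin_cache presumably keep the ORT-side element type, while the rotary kernel consumes the CUDA-side CudaT; for the new bfloat16 path those are distinct C++ types with the same bit layout. A minimal sketch, not part of the PR, of the layout assumption that makes the reinterpret_cast safe:

#include <cuda_bf16.h>                // __nv_bfloat16
#include <type_traits>
#include "core/framework/float16.h"   // onnxruntime::BFloat16 (assumed include path)

// The cast reinterprets storage rather than converting values, so the two
// types must agree in size and alignment and be safe to alias bitwise.
static_assert(sizeof(onnxruntime::BFloat16) == sizeof(__nv_bfloat16),
              "equal size required for pointer reinterpretation");
static_assert(alignof(onnxruntime::BFloat16) == alignof(__nv_bfloat16),
              "equal alignment required for pointer reinterpretation");
static_assert(std::is_trivially_copyable_v<onnxruntime::BFloat16>,
              "wrapper must be trivially copyable to alias its bytes");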
@@ -1105,6 +1105,7 @@ Status QkvToContext(
 
 template struct GroupQueryAttentionData<half, half>;
 template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>;
+template struct GroupQueryAttentionData<BFloat16, BFloat16>;
 template struct GroupQueryAttentionData<half, int8_t>;
 
 template Status QkvToContext<half, half>(
@@ -1121,6 +1122,13 @@ template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>(
     contrib::GroupQueryAttentionParameters& parameters,
     GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data);
 
+template Status QkvToContext<BFloat16, BFloat16>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<BFloat16, BFloat16>& data);
+
 template Status QkvToContext<half, int8_t>(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
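Why the new <BFloat16, BFloat16> lines are needed even though <__nv_bfloat16, __nv_bfloat16> instantiations already exist: explicit instantiation is resolved per exact type, so a translation unit that spells the argument as onnxruntime::BFloat16 (as attention.cc does below through its CudaT alias) links against different symbols than one using __nv_bfloat16, however identically the two types are laid out. A standalone illustration with hypothetical stand-in names:

// Hypothetical stand-ins (not ORT code): two bit-identical 16-bit types.
struct WrapperBF16 { unsigned short bits; };  // plays the role of onnxruntime::BFloat16
struct NativeBF16  { unsigned short bits; };  // plays the role of __nv_bfloat16

template <typename T, typename KV>
int QkvToContextLike(T* /*query*/, KV* /*kv_cache*/) { return 0; }

// Distinct types mean distinct instantiations: each line below emits its
// own symbol, and callers can only link against the spelling they used.
template int QkvToContextLike<WrapperBF16, WrapperBF16>(WrapperBF16*, WrapperBF16*);
template int QkvToContextLike<NativeBF16, NativeBF16>(NativeBF16*, NativeBF16*);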
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -196,7 +196,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   gqa_parameters.num_splits = 1;
 
   // Construct GroupQueryAttentionData
-  onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data;
+  onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT, CudaT> gqa_data;
 
   // Scratch buffers for flash/memory efficient attention
   IAllocatorUniquePtr<void> k_buffer;
@@ -355,6 +355,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   // Centralized scratch buffer allocation using GQABufferRequirements
   auto buffer_req = onnxruntime::contrib::cuda::GQABufferRequirements::Compute<T>(
       gqa_parameters,
+      false,  // use_xqa
       gqa_data.use_flash_attention,
       gqa_data.use_flash_attention_fast_decode,
       gqa_data.use_memory_efficient_attention);
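The new false argument lines up with a use_xqa parameter that GQABufferRequirements::Compute now takes ahead of the attention-backend flags; this call site simply opts out of XQA scratch space. A hedged sketch of the declaration implied by the call, with everything not visible in the diff treated as an assumption:

// Sketch only: the shape of Compute as implied by the call site above.
// The last three flag names come from the arguments passed; the struct
// members and pass-by-const-ref are guesses, not ORT's actual header.
struct GQABufferRequirements {
  template <typename T>
  static GQABufferRequirements Compute(
      const onnxruntime::contrib::GroupQueryAttentionParameters& parameters,
      bool use_xqa,                       // new in this PR; false for this op
      bool use_flash_attention,
      bool use_flash_attention_fast_decode,
      bool use_memory_efficient_attention);
  // ... per-buffer byte counts elided ...
};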
@@ -478,7 +479,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
   // Call GQA kernel (with flash or memory efficient attention)
   cublasHandle_t cublas = GetCublasHandle(context);
 
-  return onnxruntime::contrib::cuda::QkvToContext<CudaT>(
+  return onnxruntime::contrib::cuda::QkvToContext<CudaT, CudaT>(
       device_prop, cublas, context->GetComputeStream(), gqa_parameters, gqa_data);
 }
 
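Both changed call sites pass the element type twice because GroupQueryAttentionData and QkvToContext carry a second template parameter for the KV-cache type (see the <half, int8_t> instantiations in the first file, which pair a half query with a quantized cache); the plain Attention op keeps the two equal. Where CudaT comes from, and why the BFloat16 instantiations above are needed at all, is sketched below; the ToCudaType mapping is recalled from ORT's cuda_common.h and should be verified rather than taken as authoritative:

#include <cuda_fp16.h>                         // half
#include <type_traits>
#include "core/providers/cuda/cuda_common.h"   // onnxruntime::cuda::ToCudaType (assumed)

// attention.cc presumably derives its alias the usual way:
//   using CudaT = typename ToCudaType<T>::MappedType;
// If the recalled mapping is right, these asserts hold:
static_assert(std::is_same_v<
    onnxruntime::cuda::ToCudaType<onnxruntime::MLFloat16>::MappedType, half>);
// BFloat16 maps to itself, not to __nv_bfloat16, which is why
// group_query_attention_impl.cu must add explicit <BFloat16, BFloat16>
// instantiations: this file requests that exact spelling, and the
// <__nv_bfloat16, __nv_bfloat16> symbols would not satisfy the linker.
static_assert(std::is_same_v<
    onnxruntime::cuda::ToCudaType<onnxruntime::BFloat16>::MappedType,
    onnxruntime::BFloat16>);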