Commit 213a82d

Copilot and titaiwangms committed
Add float template instantiation for GQA QkvToContext
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
1 parent b86acbd commit 213a82d

File tree

1 file changed: +14 −0 lines

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu

Lines changed: 14 additions & 0 deletions
@@ -1326,6 +1326,20 @@ Status QkvToContext(
   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet.");
 }
 
+template struct GroupQueryAttentionData<float>;
+
+template Status QkvToContext<float>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    Stream* ort_stream,
+    contrib::GroupQueryAttentionParameters& parameters,
+    GroupQueryAttentionData<float>& data);
+
+template Status LaunchUnpackQKV<float, LAYOUT_BNSH>(
+    const float* packed_qkv, float* unpacked_q, float* unpacked_k, float* unpacked_v, const int num_heads,
+    const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
+    cudaStream_t stream, const int max_threads_per_block);
+
 template struct GroupQueryAttentionData<half>;
 
 template Status QkvToContext<half>(
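Background on the pattern in this diff: these are explicit template instantiations. The definitions of QkvToContext and LaunchUnpackQKV live in this .cu file, so every type they are called with from other translation units must be instantiated here, or the linker reports undefined symbols; the commit adds float instantiations alongside the existing half ones. Below is a minimal, generic sketch of that mechanism (Scale is a hypothetical function, not ONNX Runtime code), written as one compilable file with the usual header/source split indicated in comments:

#include <cstdio>

// scale.h (sketch): declaration only; the definition lives in scale.cpp,
// so callers depend on explicit instantiations being emitted there.
template <typename T>
T Scale(T value, T factor);

// scale.cpp (sketch): the template definition.
template <typename T>
T Scale(T value, T factor) {
  return value * factor;
}

// Explicit instantiations: force the compiler to emit object code for these
// specializations in this translation unit so the linker can resolve them.
template float Scale<float>(float, float);
template double Scale<double>(double, double);

// main.cpp (sketch): links only because the float symbol was instantiated above.
int main() {
  std::printf("%f\n", Scale<float>(2.0f, 3.0f));
  return 0;
}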
