Skip to content

Commit 53c333f

Browse files
Copilot and titaiwangms committed
Add debug tracking fields and num_splits parameter
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
1 parent f2007a2 commit 53c333f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

onnxruntime/core/providers/cuda/llm/attention.cc

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -225,6 +225,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
225225
gqa_parameters.local_window_size = -1; // No local window for standard attention
226226
gqa_parameters.zeros_count = 0;
227227
gqa_parameters.zero_ptr = nullptr;
228+
gqa_parameters.num_splits = 1; // No splits for unfused path
228229

229230
// Construct GroupQueryAttentionData
230231
onnxruntime::contrib::cuda::GroupQueryAttentionData<CudaT> gqa_data;
@@ -278,6 +279,16 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
278279
gqa_data.k = nullptr;
279280
gqa_data.v = nullptr;
280281

282+
#ifndef NDEBUG
283+
// Initialize debug tracking fields
284+
gqa_data.unpacked_qkv_buffer_size = 0;
285+
gqa_data.rotary_buffer_size = 0;
286+
gqa_data.position_ids_buffer_size = 0;
287+
gqa_data.unpacked_qkv_max_used = 0;
288+
gqa_data.rotary_max_used = 0;
289+
gqa_data.position_ids_max_used = 0;
290+
#endif
291+
281292
// Call GQA kernel
282293
auto& device_prop = GetDeviceProp();
283294
cublasHandle_t cublas = GetCublasHandle(context);

0 commit comments

Comments
 (0)