
Commit 4ac7cfd

review feedback

1 parent 42d912f

4 files changed: +15 -16 lines

onnxruntime/contrib_ops/cpu/utils/debug_macros.h

Lines changed: 4 additions & 4 deletions
@@ -46,11 +46,11 @@
 #define DUMP_TENSOR_D(...)
 #endif

-#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG)
-#define DEBUG_PRINTF(fmt, ...) \
+#if (defined(__GNUC__) || defined(__clang__)) && (DUMP_TENSOR_LEVEL > 0)
+#define DUMP_PRINTF(fmt, ...) \
   std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
 #else
-#define DEBUG_PRINTF(fmt, ...) \
-  do { \
+#define DUMP_PRINTF(fmt, ...) \
+  do { \
   } while (0)
 #endif
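
For reference, a minimal standalone sketch of the pattern this hunk lands on: DUMP_PRINTF is now gated on DUMP_TENSOR_LEVEL rather than NDEBUG, and the disabled branch expands to an empty do/while so call sites still parse as single statements. Everything outside the two macro definitions below is illustration only.

#include <cstdio>

#ifndef DUMP_TENSOR_LEVEL
#define DUMP_TENSOR_LEVEL 1  // assumption: set by the build in the real tree
#endif

#if (defined(__GNUC__) || defined(__clang__)) && (DUMP_TENSOR_LEVEL > 0)
// ##__VA_ARGS__ is a GNU extension that drops the trailing comma when the
// argument list is empty, which is why the #if also checks for GCC/Clang.
#define DUMP_PRINTF(fmt, ...) std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
// do { } while (0) makes the disabled macro a single statement, so
// `if (cond) DUMP_PRINTF("x");` keeps its meaning with the trailing semicolon.
#define DUMP_PRINTF(fmt, ...) \
  do {                        \
  } while (0)
#endif

int main() {
  DUMP_PRINTF("tensor level is %d", DUMP_TENSOR_LEVEL);
  DUMP_PRINTF("no extra arguments");  // compiles in both branches
  return 0;
}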

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ namespace cuda {

 namespace {
 // Map string attribute to quantization type enum
-KVQuantizationType StringToKVQuantizationType(const std::string& s) {
+KVQuantizationType StringToKVQuantizationType(std::string s) {
+  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::toupper(c); });
   if (s == "NONE") {
     return KVQuantizationType::NONE;
   }
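
The signature change takes the string by value so the function can uppercase its own copy in place, making the attribute comparison case-insensitive. A standalone sketch of the same idiom, with a hypothetical two-value enum standing in for the real KVQuantizationType:

#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>

enum class KVQuantizationType { NONE, INT8 };  // hypothetical subset for illustration

KVQuantizationType StringToKVQuantizationType(std::string s) {
  // The unsigned char cast matters: passing a negative char value to
  // std::toupper is undefined behavior, so normalize through unsigned char.
  std::transform(s.begin(), s.end(), s.begin(),
                 [](unsigned char c) { return std::toupper(c); });
  if (s == "NONE") {
    return KVQuantizationType::NONE;
  }
  return KVQuantizationType::INT8;  // assumption: fallback used only in this sketch
}

int main() {
  // "none", "None", and "NONE" all map to the same value after the transform.
  std::cout << (StringToKVQuantizationType("none") == KVQuantizationType::NONE) << '\n';  // prints 1
}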

onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu

Lines changed: 9 additions & 9 deletions
@@ -554,13 +554,13 @@ Status LaunchGetSequenceLengths(
 }

 // Trace function for debugging
-#define ORT_GQA_TRACE(func_name) \
-  DEBUG_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \
-               func_name, \
-               static_cast<int>(parameters.is_packed_qkv), \
-               static_cast<int>(parameters.is_first_prompt), \
-               static_cast<int>(parameters.is_subsequent_prompt), \
-               static_cast<int>(parameters.past_present_share_buffer));
+#define ORT_GQA_TRACE(func_name) \
+  DUMP_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \
+              func_name, \
+              static_cast<int>(parameters.is_packed_qkv), \
+              static_cast<int>(parameters.is_first_prompt), \
+              static_cast<int>(parameters.is_subsequent_prompt), \
+              static_cast<int>(parameters.past_present_share_buffer));

 ////////// Kernels (supports right padding but not left padding)
 // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path.
@@ -706,8 +706,8 @@ Status FlashDecoding(

   bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;

-  DEBUG_PRINTF("[FlashDecoding] key=%p, value=%p, present_key=%p, present_value=%p, seqlens_k=%p, is_packed_qkv=%d",
-               key, value, present_key, present_value, seqlens_k, static_cast<int>(parameters.is_packed_qkv));
+  DUMP_PRINTF("[FlashDecoding] key=%p, value=%p, present_key=%p, present_value=%p, seqlens_k=%p, is_packed_qkv=%d",
+              key, value, present_key, present_value, seqlens_k, static_cast<int>(parameters.is_packed_qkv));

   ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
       device_prop, stream, query, present_key, present_value, key, value, data.output,
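
Both hunks in this file are the mechanical rename from DEBUG_PRINTF to DUMP_PRINTF. One detail worth noting is that ORT_GQA_TRACE expands against a local variable named `parameters`, so it only compiles inside functions that declare one. A trimmed-down sketch of that pattern (the struct and fields below are stand-ins, not the real GroupQueryAttentionParameters):

#include <cstdio>

#define DUMP_PRINTF(fmt, ...) std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__)

struct Params {  // hypothetical stand-in for the real parameters struct
  bool is_packed_qkv;
  bool is_first_prompt;
};

// The backslash continuations keep the whole expansion on one logical line;
// the macro silently picks up whatever `parameters` is in scope at the call site.
#define ORT_GQA_TRACE(func_name)                                 \
  DUMP_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d", \
              func_name,                                         \
              static_cast<int>(parameters.is_packed_qkv),        \
              static_cast<int>(parameters.is_first_prompt))

int main() {
  Params parameters{true, false};
  ORT_GQA_TRACE("main");  // prints: [DEBUG] [GQA main] is_packed_qkv: 1, is_first_prompt: 0
  return 0;
}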

onnxruntime/contrib_ops/cuda/bert/xqa/barriers.cuh

Lines changed: 0 additions & 2 deletions
@@ -356,9 +356,7 @@ class MBarrier // rename this to MBarrier
     } else {
       float sleepDuration = 0.125F;
       while (!func()) {
-        // if (sleepDuration > 1) {
         __nanosleep(uint32_t(sleepDuration));
-        // }
         sleepDuration = sleepDuration * 1.25F + 0.F;
       }
     }
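
The deletion removes a commented-out guard, so every iteration of the spin loop now calls __nanosleep, with geometric backoff (x1.25 per miss). __nanosleep is a CUDA device intrinsic (sm_70+), so the sketch below is a host-side analogue of the same backoff shape, not the device code:

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <thread>

int main() {
  int polls = 0;
  auto func = [&] { return ++polls >= 8; };  // hypothetical predicate, true on the 8th poll

  float sleepDuration = 0.125F;  // same starting value as the diff
  while (!func()) {
    // Truncation to uint32_t means the first few iterations sleep 0 ns;
    // the wait only becomes real once growth pushes sleepDuration past 1.
    std::this_thread::sleep_for(std::chrono::nanoseconds(uint32_t(sleepDuration)));
    sleepDuration = sleepDuration * 1.25F;  // geometric backoff, as in the loop above
  }
  std::printf("predicate satisfied after %d polls\n", polls);
  return 0;
}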
