Skip to content

Commit 2e10874

Browse files
Copilot and titaiwangms committed
Enable CUDA tests for GQA attention tests
Co-authored-by: titaiwangms <18010845+titaiwangms@users.noreply.github.com>
1 parent 0e7a632 commit 2e10874

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onnxruntime/test/providers/cpu/llm/attention_op_test.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -891,7 +891,7 @@ TEST(AttentionTest, Attention3DGqaAttn) {
891891
q, k, v, std::vector<float>(), std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
892892
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
893893
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
894-
false, true, true // disable_cpu, disable_cuda, disable_dml
894+
false, false, true // disable_cpu, disable_cuda, disable_dml
895895
);
896896
}
897897

@@ -926,7 +926,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) {
926926
q, k, v, m, std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
927927
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
928928
y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
929-
false, true, true // disable_cpu, disable_cuda, disable_dml
929+
false, false, true // disable_cpu, disable_cuda, disable_dml
930930
);
931931
}
932932

@@ -973,7 +973,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {
973973
q, k, v, m, std::initializer_list<bool>(), past_key, past_value,
974974
-1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
975975
y, present_key, present_value, std::vector<float>(),
976-
false, true, true // disable_cpu, disable_cuda, disable_dml
976+
false, false, true // disable_cpu, disable_cuda, disable_dml
977977
);
978978
}
979979

0 commit comments

Comments (0)