@@ -891,7 +891,7 @@ TEST(AttentionTest, Attention3DGqaAttn) {
891891 q, k, v, std::vector<float>(), std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
892892 -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
893893 y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
894- false, true , true // disable_cpu, disable_cuda, disable_dml
894+ false, false , true // disable_cpu, disable_cuda, disable_dml
895895 );
896896}
897897
@@ -926,7 +926,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) {
926926 q, k, v, m, std::initializer_list<bool>(), std::vector<float>(), std::vector<float>(),
927927 -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
928928 y, std::vector<float>(), std::vector<float>(), std::vector<float>(),
929- false, true , true // disable_cpu, disable_cuda, disable_dml
929+ false, false , true // disable_cpu, disable_cuda, disable_dml
930930 );
931931}
932932
@@ -973,7 +973,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) {
973973 q, k, v, m, std::initializer_list<bool>(), past_key, past_value,
974974 -1, -1, std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type
975975 y, present_key, present_value, std::vector<float>(),
976- false, true , true // disable_cpu, disable_cuda, disable_dml
976+ false, false , true // disable_cpu, disable_cuda, disable_dml
977977 );
978978}
979979
0 commit comments