Skip to content

Commit 2a52fc1

Browse files
committed
lintrunner
1 parent d28d92d commit 2a52fc1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

onnxruntime/test/python/transformers/test_mha_flash_attn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,9 @@ def parity_check_mha(
371371
out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
372372
out = out.detach().cpu().numpy()
373373
# Pytorch to compare
374-
out_ref, _ = attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False)
374+
out_ref, _ = attention_ref(
375+
q, k, v, query_padding_mask=None, key_padding_mask=None, attention_bias=None, causal=False
376+
)
375377
out_ref = out_ref.detach().cpu().numpy()
376378

377379
numpy.testing.assert_allclose(

0 commit comments

Comments
 (0)