Skip to content

Commit

Permalink
Fix dk/dv autograd error on TPU flash attention (#8685)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmelumian972 authored Feb 7, 2025
1 parent afaf0d0 commit 0cd1fc2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def fa_custom_backward(
if require_grad_ab:
grad_ab = grads[1]

if require_grad_k or require_grad_k:
if require_grad_k or require_grad_v:
payload, _ = trace_pallas(
_flash_attention_bwd_dkv,
q,
Expand Down

0 comments on commit 0cd1fc2

Please sign in to comment.