[Question] Does jax.nn.dot_product_attention invoke Flash Attention on TPUs?
#27014
Unanswered
shamuiscoding asked this question in Q&A
Replies: 0 comments
I'm curious whether it does or does not.
Looking at the traces, I don't think it does, but I may be mistaken.
So you have to use
from jax.experimental.pallas.ops.tpu import flash_attention
to invoke flash attention on TPU, no matter what, right? This is the
vmap(BTNH,BSNH->BNTS) inside the dot_product_attention implementation:
%fusion.1247 = (f32[8,32,2688]{2,1,0:T(8,128)S(3)}, f32[8,32,2688,2688]{2,3,1,0:T(8,128)}) fusion(bf16[8,2688,32,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.1960, bf16[8,2688,32,64]{1,2,0,3:T(8,128)(2,1)} %fusion.1245, pred[2688,2688]{0,1:T(8,128)(4,1)S(3)} %custom-call.242), kind=kOutput, calls=%fused_computation.1103.clone.clone
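For context, here is a sketch of the two paths being compared. Everything in it is an assumption on my part rather than a confirmed answer: the layout conventions (BTNH for jax.nn.dot_product_attention, BNTH for the Pallas kernel), the flash_attention call signature, and the claim that 'cudnn' is the only non-XLA implementation option are based on my reading of current JAX and may not hold for every version.

```python
# A minimal sketch, not an authoritative answer. Assumptions: the Pallas TPU
# flash_attention kernel takes (batch, heads, seq, head_dim) inputs with
# keyword-only causal/sm_scale, and jax.nn.dot_product_attention takes
# (batch, seq, heads, head_dim); both may differ across JAX versions.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import flash_attention

B, T, N, H = 8, 2688, 32, 64  # batch, seq len, heads, head dim (from the trace above)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# High-level API: on TPU this goes through the XLA path; the `implementation`
# options I'm aware of are None/'xla'/'cudnn', and 'cudnn' is GPU-only, so
# there appears to be no flag here that selects a TPU flash kernel.
out_xla = jax.nn.dot_product_attention(q, k, v, is_causal=True)

# Explicit Pallas TPU kernel: transpose BTNH -> BNTH and match the 1/sqrt(H)
# scaling that dot_product_attention applies by default.
q_, k_, v_ = (x.transpose(0, 2, 1, 3) for x in (q, k, v))
out_flash = flash_attention.flash_attention(
    q_, k_, v_, causal=True, sm_scale=1.0 / (H ** 0.5)
)
out_flash = out_flash.transpose(0, 2, 1, 3)  # back to BTNH for comparison
```

To confirm what jax.nn.dot_product_attention actually lowers to on a given TPU, you can inspect the compiled HLO, e.g. jax.jit(fn).lower(q, k, v).compile().as_text(), and look for a Pallas/Mosaic custom call versus a plain softmax/dot fusion like the one above.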