[Question] Does jax.nn.dot_product_attention invoke Flash Attention on TPUs? #27014
Unanswered
shamuiscoding asked this question in Q&A
I'm curious whether it does or not. Looking at the traces, I don't think it does, but I may be mistaken.
So you have to use
from jax.experimental.pallas.ops.tpu import flash_attention
to invoke flash attention on TPU no matter what, right? Below is the vmap(BTNH,BSNH->BNTS) fusion from inside the dot_product_attention implementation:
%fusion.1247 = (f32[8,32,2688]{2,1,0:T(8,128)S(3)}, f32[8,32,2688,2688]{2,3,1,0:T(8,128)}) fusion(bf16[8,2688,32,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.1960, bf16[8,2688,32,64]{1,2,0,3:T(8,128)(2,1)} %fusion.1245, pred[2688,2688]{0,1:T(8,128)(4,1)S(3)} %custom-call.242), kind=kOutput, calls=%fused_computation.1103.clone.clone
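For anyone who wants to check this themselves, here is a minimal sketch, not an authoritative implementation: it dumps the compiled HLO so you can see what jax.nn.dot_product_attention lowers to, and then calls the Pallas TPU flash-attention kernel directly. The shapes are borrowed from the fusion above, and the transposes assume the Pallas op expects (batch, heads, seq, head_dim) layout while jax.nn.dot_product_attention takes (batch, seq, heads, head_dim).

```python
import jax
import jax.numpy as jnp

# Shapes taken from the HLO fusion above: batch=8, seq=2688, heads=32, head_dim=64.
B, T, N, H = 8, 2688, 32, 64
q = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
k = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
v = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)

def f(q, k, v):
    # Default implementation=None; on TPU this appears to take the plain XLA path.
    return jax.nn.dot_product_attention(q, k, v, is_causal=True)

# Dump the compiled HLO. A Pallas/Mosaic flash kernel would show up as a
# custom-call; a plain softmax(Q K^T) V shows up as fusions like the
# %fusion.1247 quoted above.
print(jax.jit(f).lower(q, k, v).compile().as_text())

# Calling the Pallas TPU flash-attention kernel explicitly (TPU-only).
from jax.experimental.pallas.ops.tpu import flash_attention

def f_flash(q, k, v):
    out = flash_attention.flash_attention(
        q.transpose(0, 2, 1, 3),  # BTNH -> BNTH
        k.transpose(0, 2, 1, 3),
        v.transpose(0, 2, 1, 3),
        causal=True,
        sm_scale=1.0 / H**0.5,  # match the usual 1/sqrt(head_dim) scaling
    )
    return out.transpose(0, 2, 1, 3)  # BNTH -> BTNH
```

Note also that jax.nn.dot_product_attention exposes an implementation= argument, but as far as I can tell the only non-default kernel backend it accepts is 'cudnn', which is GPU-only, so it cannot select a TPU flash kernel either way.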