Implement flash decoding as described here: https://pytorch.org/blog/flash-decoding/
We have attention operators grouped like this:
Q -> [B, M, k]
K -> [B, k, N]
V -> [B, N, D]
S = dot(Q, K)
P = softmax(S)
O = dot(P, V) # [B, M, D]
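For reference, a minimal NumPy sketch of the grouped attention above (illustrative only; names and shapes follow the pseudocode, not MIGraphX operators):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: [B, M, k], K: [B, k, N], V: [B, N, D]
    S = Q @ K                # [B, M, N]
    P = softmax(S, axis=-1)  # [B, M, N]
    return P @ V             # [B, M, D]
```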
To do flash decoding we will need to add another batch dimension G for the groups we split the sequence (N) axis into, and then do:
Q -> [B, G, M, k] # G is a broadcasted dimension
K -> [B, G, k, N/G]
V -> [B, G, N/G, D]
# first kernel
S = dot(Q, K)
P = softmax(S, axis=-1)
L = logsumexp(S, axis=-1) # [B, G, M, 1]
O' = dot(P, V) # [B, G, M, D]
# second kernel
scale = softmax(L, axis=1) # [B, G, M, 1]
R = mul(O', broadcast(scale)) # [B, G, M, D]
O = sum(R, axis=1) # [B, 1, M, D]
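A NumPy sketch of this split formulation, with a check that the two-kernel decomposition reproduces the unsplit result (the function name, choice of G, and test shapes are illustrative, not part of the proposed pass):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def logsumexp(x, axis=-1):
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def flash_decode_attention(Q, K, V, G):
    # Q: [B, M, k], K: [B, k, N], V: [B, N, D]; G must divide N
    B, M, k = Q.shape
    N, D = V.shape[1], V.shape[2]
    Kg = K.reshape(B, k, G, N // G).transpose(0, 2, 1, 3)  # [B, G, k, N/G]
    Vg = V.reshape(B, G, N // G, D)                        # [B, G, N/G, D]
    Qg = Q[:, None]                                        # [B, 1, M, k], G broadcast

    # first kernel: per-group attention plus log-sum-exp of the raw scores
    S = Qg @ Kg                    # [B, G, M, N/G]
    P = softmax(S, axis=-1)
    L = logsumexp(S, axis=-1)      # [B, G, M, 1]
    O_part = P @ Vg                # [B, G, M, D]

    # second kernel: weight each group by its share of the global softmax
    # denominator (softmax of the LSE values over G), then reduce over G
    scale = softmax(L, axis=1)     # [B, G, M, 1]
    return (O_part * scale).sum(axis=1)  # [B, M, D]

# equivalence check against the unsplit formulation
rng = np.random.default_rng(0)
Q = rng.standard_normal((2, 1, 64))    # B=2, M=1 (decode step), k=64
K = rng.standard_normal((2, 64, 128))  # N=128
V = rng.standard_normal((2, 128, 32))  # D=32
ref = softmax(Q @ K, axis=-1) @ V
assert np.allclose(ref, flash_decode_attention(Q, K, V, G=4))
```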
We will probably do this directly in the fuse_attention pass after we have done the initial attention grouping.