Implement flash decoding #4334

@pfultz2

Description

Implement flash decoding as described here: https://pytorch.org/blog/flash-decoding/

We have attention operators grouped like this:

Q -> [B, M, k]
K -> [B, k, N]
V -> [B, N, D]

S = dot(Q, K)
P = softmax(S, axis=-1)
O = dot(P, V) # [B, M, D]
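
For reference, a minimal NumPy sketch of this grouped form (the function and variable names are illustrative, not the actual MIGraphX operators):

import numpy as np

def softmax(x, axis=-1):
    # subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: [B, M, k], K: [B, k, N], V: [B, N, D]
    S = Q @ K                # [B, M, N]
    P = softmax(S, axis=-1)  # [B, M, N]
    return P @ V             # [B, M, D]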

To do flash decoding, we will need to add another batch dimension G that splits the K/V sequence dimension N into groups, and then do:

Q -> [B, G, M, k] # G is a broadcasted dimension
K -> [B, G, k, N/G]
V -> [B, G, N/G, D]

# first kernel
S = dot(Q, K)
P = softmax(S, axis=-1)
L = LSE(S) # [B, G, M, 1]
O' = dot(P, V) # [B, G, M, D]

# second kernel
scale = softmax(L, axis=1) # [B, G, M, 1]
R = mul(O', broadcast(scale)) # [B, G, M, D]
O = sum(R, axis=1) # [B, 1, M, D]
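
A NumPy sketch of the split, again with illustrative names and assuming N is divisible by G; up to floating-point error it should reproduce the single-kernel result above:

def flash_decode(Q, K, V, G):
    B, M, k = Q.shape
    N = K.shape[-1]
    D = V.shape[-1]

    # Split the K/V sequence dimension N into G groups; Q is broadcast over G.
    Kg = K.reshape(B, k, G, N // G).transpose(0, 2, 1, 3)  # [B, G, k, N/G]
    Vg = V.reshape(B, G, N // G, D)                        # [B, G, N/G, D]
    Qg = Q[:, None]                                        # [B, 1, M, k]

    # First kernel: per-group softmax attention plus each group's log-sum-exp.
    S = Qg @ Kg                                            # [B, G, M, N/G]
    m = S.max(axis=-1, keepdims=True)
    e = np.exp(S - m)
    L = m + np.log(e.sum(axis=-1, keepdims=True))          # LSE, [B, G, M, 1]
    P = e / e.sum(axis=-1, keepdims=True)                  # softmax within the group
    O_partial = P @ Vg                                     # [B, G, M, D]

    # Second kernel: weight each group's partial output by softmax(L) over G and sum.
    scale = np.exp(L - L.max(axis=1, keepdims=True))
    scale = scale / scale.sum(axis=1, keepdims=True)       # [B, G, M, 1]
    return (O_partial * scale).sum(axis=1)                 # [B, M, D]

On random inputs, np.allclose(flash_decode(Q, K, V, G), attention(Q, K, V)) should hold, which checks that rescaling each group's partial output by the softmax of the per-group LSEs recovers the global softmax.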

We will probably do this directly in the fuse_attention pass after we have done the initial attention grouping.
