Commit d43cbca (parent 4367317)

Expanded vmap support for flash_mha. Vmapping q but not k,v reduces to grouped-query attention, which we now support.
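The q-only case is the interesting one: previously, vmapping flash_mha with k and v unmapped raised NotImplementedError. A minimal sketch of the call pattern this commit enables, assuming flash_mha is the package's public entry point (the flash-attn kernels need fp16/bf16 inputs on a CUDA GPU; names and shapes here are illustrative, not from the diff):

    import jax
    import jax.numpy as jnp
    from flash_attn_jax import flash_mha  # assumed public entry point

    n, l, h, d = 2, 128, 4, 32  # batch, sequence length, heads, head dim
    x = 4                       # vmapped axis: x query tensors share one k/v

    q = jnp.zeros([x, n, l, h, d], dtype=jnp.float16)
    k = jnp.zeros([n, l, h, d], dtype=jnp.float16)
    v = jnp.zeros([n, l, h, d], dtype=jnp.float16)

    # Mapping only q (in_axes=(0, None, None)) now lowers to a single
    # grouped-query attention call instead of raising NotImplementedError.
    out = jax.vmap(flash_mha, in_axes=(0, None, None))(q, k, v)
    print(out.shape)  # (x, n, l, h, d)

Under the hood the batch rule folds the x axis into q's head dimension and issues one grouped-query attention call, rather than x separate MHA calls.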

2 files changed (+182, -25 lines)

Diff for: src/flash_attn_jax/flash.py (+84, -20)
@@ -19,6 +19,7 @@
 from jax.sharding import PositionalSharding
 
 from einops import rearrange
+import einops
 import math
 
 from .flash_sharding import _flash_mha_fwd_hlo_sharded, _flash_mha_bwd_hlo_sharded
@@ -104,28 +105,91 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
 # ==== VMap rules ====
 
 def mha_fwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0), "Only support vmapping mha over axis 0 for now,"
-    [q, k, v] = vector_arg_values
-    [b, n, l, h, d] = q.shape
-    [b, n, lk, hk, d] = k.shape
-    assert [b, n, lk, hk, d] == list(v.shape)
-    out, lse = _flash_mha_fwd_p.bind(q.reshape([b*n,l,h,d]),
-                                     k.reshape([b*n,lk,hk,d]),
-                                     v.reshape([b*n,lk,hk,d]),
-                                     **kwargs)
-    return (out.reshape([b,n,*out.shape[1:]]), lse.reshape([b,n,*lse.shape[1:]])), (0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True):
+        # All of q, k, v are mapped: fold the vmapped axis into the batch dim.
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> (x n) l h d')
+        def unsquish(val):
+            return einops.rearrange(val, '(x n) ... -> x n ...', x=x)
+        [q, k, v] = [squish(val, ax) for val, ax in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        return (unsquish(out), unsquish(lse)), (0,0)
+    elif mapped == (True, False, False):
+        # Only q is mapped: this is just a GQA! Fold the vmapped axis into
+        # q's head dim, so each kv head serves x query heads.
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if axis is None:
+                return val
+            dims = ['n', 'l', 'h', 'd']
+            dims.insert(axis, 'x')
+            dims = ' '.join(dims)
+            return einops.rearrange(val, f'{dims} -> n l (h x) d')
+        def unsquish(val):
+            return einops.rearrange(val, 'n l (h x) d -> x n l h d', x=x)
+        [q, k, v] = [squish(val, ax) for val, ax in zip(vector_arg_values, batch_axes)]
+        out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
+        out = unsquish(out)
+        lse = einops.rearrange(lse, 'n (h x) l -> x n h l', x=x)
+        return (out, lse), (0,0)
+    else:
+        raise NotImplementedError("MHA fwd only supports vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))
 
 def mha_bwd_batch(vector_arg_values, batch_axes, **kwargs):
-    assert tuple(batch_axes) == (0,0,0,0,0,0), "Only support vmapping mha over axis 0 for now,"
-    dout, q, k, v, out, lse = vector_arg_values
-    b = dout.shape[batch_axes[0]]
-    def join(*args):
-        return [rearrange(a, 'b n ... -> (b n) ...') for a in args]
-    def unjoin(*args):
-        return [rearrange(a, '(b n) ... -> b n ...', b=b) for a in args]
-    dq, dk, dv = _flash_mha_bwd_p.bind(*join(dout,q,k,v,out,lse),
-                                       **kwargs)
-    return tuple(unjoin(dq,dk,dv)), (0,0,0)
+    assert all(isinstance(b, int) or b is None for b in batch_axes)
+    mapped = tuple(isinstance(b, int) for b in batch_axes)
+    if mapped == (True, True, True, True, True, True):
+        # Everything is mapped: fold the vmapped axis into the batch dim.
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) l h d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> (x n) h l')
+        do, q, k, v, o, lse = [squish(val, ax) for val, ax in zip(vector_arg_values, batch_axes)]
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        # Unpack in the same (x outer, n inner) order that squish packed.
+        dq = einops.rearrange(dq, '(x n) l h d -> x n l h d', x=x)
+        dk = einops.rearrange(dk, '(x n) l h d -> x n l h d', x=x)
+        dv = einops.rearrange(dv, '(x n) l h d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,0,0)
+    elif mapped == (True, True, False, False, True, True):
+        # Everything is mapped except k and v, which is a GQA backward.
+        x = vector_arg_values[0].shape[batch_axes[0]]
+        def squish(val, axis):
+            if len(val.shape) == 5:
+                # q/k/v/o
+                dims = ['n', 'l', 'h', 'd']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n l (h x) d')
+            elif len(val.shape) == 4:
+                # lse
+                dims = ['n', 'h', 'l']
+                dims.insert(axis, 'x')
+                dims = ' '.join(dims)
+                return einops.rearrange(val, f'{dims} -> n (h x) l')
+        do = squish(vector_arg_values[0], batch_axes[0])
+        q = squish(vector_arg_values[1], batch_axes[1])
+        k = vector_arg_values[2]
+        v = vector_arg_values[3]
+        o = squish(vector_arg_values[4], batch_axes[4])
+        lse = squish(vector_arg_values[5], batch_axes[5])
+        dq, dk, dv = _flash_mha_bwd_p.bind(do, q, k, v, o, lse, **kwargs)
+        dq = einops.rearrange(dq, 'n l (h x) d -> x n l h d', x=x)
+        return (dq,dk,dv), (0,None,None)
+    else:
+        raise NotImplementedError("MHA bwd only supports vmapping over q or (q,k,v) for now, got batch axes " + str(batch_axes))
 
 batching.primitive_batchers[_flash_mha_fwd_p] = mha_fwd_batch
 batching.primitive_batchers[_flash_mha_bwd_p] = mha_bwd_batch
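To make the head-folding concrete, here is a small standalone sketch of the rearranges the GQA path uses (only numpy and einops are assumed; shapes are illustrative):

    import numpy as np
    import einops

    x, n, l, h, d = 4, 2, 16, 3, 8
    q = np.zeros([x, n, l, h, d])

    # Fold the vmapped axis into q's head axis, as squish does in the
    # q-only (GQA) case: h heads become h*x grouped query heads.
    qg = einops.rearrange(q, 'x n l h d -> n l (h x) d')
    print(qg.shape)  # (2, 16, 24, 8), i.e. (n, l, h*x, d)

    # Undo the fold, as the output/lse rearranges do.
    back = einops.rearrange(qg, 'n l (h x) d -> x n l h d', x=x)
    assert back.shape == q.shape

Packing as (h x) keeps the x copies of each original head adjacent, which is what the inverse patterns rely on.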

Diff for: tests/test_flash.py (+98, -5)
@@ -4,6 +4,7 @@
 sys.path.insert(0, glob.glob('build/lib.linux-*')[0])
 sys.path.insert(0,'./src')
 
+from functools import partial
 import pytest
 import jax
 import jax.numpy as jnp
@@ -100,22 +101,114 @@ def test_flash_fwd_vmap(n, seqlen, h, d, causal, local, dtype):
     k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
 
-    @jax.jit
     def ref(q,k,v):
         return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
-    @jax.jit
     def flash(q,k,v):
         return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
 
-    ref_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
+    ref_out = jax.vmap(ref)(q,k,v)
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    f16_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)])
-
+    f16_out = jax.vmap(ref)(q,k,v)
 
     out = jax.vmap(flash)(q,k,v)
     check(ref_out, f16_out, out)
 
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_fwd_vmapq(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+
+    def ref(q,k,v):
+        return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+    def flash(q,k,v):
+        return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+
+    ref_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    f16_out = jax.vmap(ref, in_axes=(0,None,None))(q,k,v)
+
+    out = jax.vmap(flash, in_axes=(0,None,None))(q,k,v)
+    check(ref_out, f16_out, out)
+
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_bwd_vmap(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)
+
+    def func(mha, q,k,v):
+        @partial(jax.vmap, in_axes=(0,0,0))
+        def fwd(q,k,v):
+            return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+        o, bwd = jax.vjp(fwd,q,k,v)
+        return bwd(do)
+
+    ref_out = func(ref_mha, q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    do = do.astype(dtype)
+    f16_out = func(ref_mha, q,k,v)
+
+    out = func(flash_mha, q,k,v)
+    check(ref_out, f16_out, out)
+
+@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
+@pytest.mark.parametrize("local", ['local',''])
+@pytest.mark.parametrize("causal", ['causal',''])
+@pytest.mark.parametrize("d", [59, 32])
+@pytest.mark.parametrize("h", [1, 4])
+@pytest.mark.parametrize("seqlen", [97, 128])
+@pytest.mark.parametrize("n", [1])
+def test_flash_bwd_vmapq(n, seqlen, h, d, causal, local, dtype):
+    window_size = (3,3) if local else (-1,-1)
+
+    x = 4
+    q = jax.random.normal(jax.random.PRNGKey(0), [x, n, seqlen, h, d], dtype=jnp.float32)
+    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
+    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [x, n, seqlen, h, d], dtype=jnp.float32)
+
+    def func(mha, q,k,v):
+        @partial(jax.vmap, in_axes=(0,None,None))
+        def fwd(q,k,v):
+            return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
+        o, bwd = jax.vjp(fwd,q,k,v)
+        return bwd(do)
+
+    ref_out = func(ref_mha, q,k,v)
+    q = q.astype(dtype)
+    k = k.astype(dtype)
+    v = v.astype(dtype)
+    do = do.astype(dtype)
+    f16_out = func(ref_mha, q,k,v)
+
+    out = func(flash_mha, q,k,v)
+    check(ref_out, f16_out, out)
+
 if __name__ == '__main__':
     test_flash_bwd(1,4,1,32,4,False,False,jnp.float16)
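The two new vmapq tests exercise the q-only path end to end. Assuming the extension is built (the sys.path insert above expects a build/lib.linux-* directory), they can be run in isolation with:

    pytest tests/test_flash.py -k vmapq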