
Commit bc9a01d

Implement ring attention backward pass. More tests.

1 parent: 9f80518

4 files changed: +177 -81 lines

Diff for: README.md (+14 -2)
@@ -11,10 +11,11 @@ Please cite (see below) and credit FlashAttention if you use it.
 ## Installation and features
 
 Requirements:
-- CUDA 11.6 and above.
+- CUDA 11.8 and above.
 - Linux. Same story as with the pytorch repo. I haven't tested compilation of the jax bindings on windows.
+- JAX >= `0.4.24`. The custom sharding used for ring attention requires some fairly advanced features.
 
-To install: TODO
+To install: For now, download the appropriate release from the releases page and install it with pip.
 
 Interface: `src/flash_attn_jax/flash.py`
 
@@ -28,6 +29,17 @@ Accepts q,k,v with shape `[n, l, h, d]`, and returns `[n, l, h, d]`.
 multiplier for the softmax, defaulting to `1/sqrt(d)`. Set window_size
 to positive values for sliding window attention.
 
+### Now Supports Ring Attention
+
+Use jax.Array and shard your tensors along the length dimension, and flash_mha will automatically use the ring attention algorithm:
+
+```py
+with Mesh(devices, axis_names=('len',)) as mesh:
+    sharding = NamedSharding(mesh, P(None,'len',None)) # n l d
+    tokens = jax.device_put(tokens, sharding)
+    # invoke your jax.jit'd transformer.forward
+```
+
 FlashAttention-2 currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
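
For reference, the README change above amounts to the following end-to-end usage, sketched here with illustrative shapes and four local devices assumed; only `flash_mha` is taken from the library, everything else is made up for the example:

```py
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.local_devices()                     # assume e.g. 4 GPUs
n, l, h, d = 2, 4096, 8, 64                       # l must divide evenly across devices
q, k, v = (jax.random.normal(key, [n, l, h, d], dtype=jnp.float16)
           for key in jax.random.split(jax.random.PRNGKey(0), 3))

with Mesh(np.array(devices), axis_names=('len',)) as mesh:
    # Shard the length dimension; the custom partitioner then picks the ring path.
    sharding = NamedSharding(mesh, P(None, 'len', None, None))
    q, k, v = jax.device_put((q, k, v), sharding)
    out = jax.jit(flash_mha)(q, k, v)             # [n, l, h, d], sharded like the inputs
```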

Diff for: src/flash_attn_jax/__init__.py (+1 -1)

@@ -1,2 +1,2 @@
 from .flash import flash_mha
-__version__ = 'v2.5.0'
+__version__ = 'v2.5.5'

Diff for: src/flash_attn_jax/flash_sharding.py (+123 -55)
@@ -30,53 +30,6 @@
 
 from jax._src.ad_checkpoint import _optimization_barrier
 
-def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
-    [n,l,h,d] = q.shape
-
-    q_ix = jax.lax.axis_index(axis_name)
-    k_ix = jax.lax.axis_index(axis_name)
-
-    o = jnp.zeros([n,l,h,d], jnp.float32)
-    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(c, a):
-        (k, v, o, lse, k_ix) = c
-
-        o1, lse1 = o, lse
-        if is_causal:
-            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
-                                      [
-                                          lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
-                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
-                                      ], q, k, v)
-        else:
-            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-        o2 = o2.astype(jnp.float32)
-
-        mx = jnp.maximum(lse1,lse2)
-        mn = jnp.minimum(lse1,lse2)
-        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
-
-        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
-             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
-
-        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-
-        return ((k2, v2, o, lse, k_ix), None)
-    acc = (k,v,o,lse,k_ix)
-    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
-    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (_,_,o,lse,_) = acc
-    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
-    return o.astype(q.dtype), lse
-
 def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
@@ -147,21 +100,136 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     o_sharding = arg_shardings[4]
     lse_sharding = arg_shardings[5]
     if isinstance(q_sharding, PositionalSharding):
-        do_sharding = q_sharding.replicate((1,3))
-        [n, l, h, d] = do_sharding.shape
-        lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
-        result_shardings = (do_sharding,)*3
-        arg_shardings = (do_sharding,)*5 + (lse_sharding,)
+        assert q_sharding == k_sharding, "Expect q and k sharding to match"
+        assert q_sharding == v_sharding, "Expect q and v sharding to match"
+        [n, l, h, d] = q_sharding.shape
+        assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
+        assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
+        lse_sharding = q_sharding.reshape(n,h,1) # n h l
+        result_shardings = (q_sharding,)*3
+        arg_shardings = (q_sharding,)*5 + (lse_sharding,)
     elif isinstance(q_sharding, NamedSharding):
         mesh = q_sharding.mesh
         [n,l,h,d] = q_sharding.spec
-        do_sharding = NamedSharding(mesh, P(n,None,h,None))
-        lse_sharding = NamedSharding(mesh, P(n,h,None))
-        result_shardings = (do_sharding,)*3
+        assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
+        if l != None:
+            # assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
+            assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
+            result_shardings = q_sharding, q_sharding, q_sharding
+            lse_sharding = NamedSharding(mesh, P(n,h,l))
+            arg_shardings = (q_sharding,)*5 + (lse_sharding,)
+            axis_name = l
+            axis_size = mesh.shape[axis_name]
+            # ring attention
+            return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
+        else:
+            result_shardings = q_sharding, q_sharding, q_sharding
+            lse_sharding = NamedSharding(mesh, P(n,h,l))
+            arg_shardings = (q_sharding,)*5 + (lse_sharding,)
     def fwd(*args):
         return _flash_mha_bwd_hlo(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return mesh, fwd, result_shardings, arg_shardings
 
 _flash_mha_bwd_hlo_sharded.def_partition(
     infer_sharding_from_operands=infer_sharding_bwd,
     partition=partition_bwd)
+
+# ==== Ring Forward ====
+
+def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
+    [n,l,h,d] = q.shape
+
+    q_ix = jax.lax.axis_index(axis_name)
+    k_ix = jax.lax.axis_index(axis_name)
+
+    o = jnp.zeros([n,l,h,d], jnp.float32)
+    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(c, a):
+        (k, v, o, lse, k_ix) = c
+
+        o1, lse1 = o, lse
+        if is_causal:
+            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
+                                      [
+                                          lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                                      ], q, k, v)
+        else:
+            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+        o2 = o2.astype(jnp.float32)
+
+        mx = jnp.maximum(lse1,lse2)
+        mn = jnp.minimum(lse1,lse2)
+        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
+
+        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
+        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2, v2, o, lse, k_ix), None)
+    acc = (k,v,o,lse,k_ix)
+    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
+    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (_,_,o,lse,_) = acc
+    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
+    return o.astype(q.dtype), lse
+
+# ==== Ring Backward ====
+
+# This doesn't seem like the most efficient way to do this: we waste compute by calculating every dq,dk,dv contribution twice.
+# Should we send the dk,dv accumulators cross-device instead, relying on the fact that after a full cycle they return to the starting device?
+def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do,q,k,v,o,lse):
+    [n,l,h,d] = q.shape
+
+    ix = jax.lax.axis_index(axis_name)
+
+    dq = jnp.zeros([n,l,h,d], jnp.float32)
+    dk = jnp.zeros([n,l,h,d], jnp.float32)
+    dv = jnp.zeros([n,l,h,d], jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(acc, a):
+        (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
+
+        cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
+        # 0: ix < ix2
+        # 1: ix = ix2
+        # 2: ix > ix2
+        if is_causal:
+            dqa = jax.lax.switch(cmp, [
+                lambda q,k,v: jnp.zeros([n,l,h,d], q.dtype),
+                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
+                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
+            ], q, k, v)
+            dka,dva = jax.lax.switch(cmp, [
+                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
+                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
+                lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype),jnp.zeros([n,l,h,d], q.dtype)),
+            ], q, k, v)
+        else:
+            dqa,_,_ = _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+            _,dka,dva = _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+
+        dq += dqa
+        dk += dka
+        dv += dva
+
+        (do2,q2,k2,v2,o2,lse2,ix2) = jax.lax.ppermute((do2,q2,k2,v2,o2,lse2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv), None)
+    acc = (do,q,k,v,o,lse,ix,dq,dk,dv)
+    # Unrolled as above.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
+    return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
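
For reference, the merge that `ring_fwd` performs on each step is the standard streaming-softmax combination: a block's `(o2, lse2)` is folded into the running `(o, lse)` so that, after visiting every key block, the result equals attention over the full key set. A single-head sketch with illustrative shapes (no softmax scaling, not part of the commit):

```py
import jax
import jax.numpy as jnp

def block_attn(q, k, v):
    # Attention over one block of keys, also returning the log-sum-exp of the
    # scores, analogous to the lse returned by the flash kernel.
    s = q @ k.T                               # [lq, lk] scores
    lse = jax.nn.logsumexp(s, axis=-1)        # [lq]
    return jnp.exp(s - lse[:, None]) @ v, lse

k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k0, [8, 16])
k = jax.random.normal(k1, [32, 16])
v = jax.random.normal(k2, [32, 16])

o_full, _ = block_attn(q, k, v)               # attention over all 32 keys

# Two disjoint key blocks, merged the way ring_fwd merges (o1, lse1) and (o2, lse2).
o1, lse1 = block_attn(q, k[:16], v[:16])
o2, lse2 = block_attn(q, k[16:], v[16:])
lse = jnp.logaddexp(lse1, lse2)               # same as log1p(exp(mn - mx)) + mx
o = o1 * jnp.exp(lse1 - lse)[:, None] + o2 * jnp.exp(lse2 - lse)[:, None]

assert jnp.allclose(o, o_full, atol=1e-5)
```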

Diff for: tests/test_sharding.py (+39 -23)
@@ -103,8 +103,7 @@ def with_sharding(q_sharding, kv_sharding=None):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,2])
-def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
+def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype):
     window_size = (3,3) if local else (-1,-1)
 
     devices = jax.local_devices()[:4]
@@ -117,19 +116,35 @@ def test_flash_bwd_sharded_hlo(seqlen, h, d, causal, local, dtype, shard_dim):
     def flash(qkv):
         return (flash_mha(*qkv, is_causal=bool(causal), window_size=window_size)**2).sum()
 
-    q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
-    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
-    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
+    def with_sharding(sharding):
+        q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=dtype)
+        k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=dtype)
+        v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=dtype)
+        (q,k,v) = jax.device_put((q,k,v), sharding)
+        hlo = flash.lower((q,k,v)).compile().as_text()
+        return hlo
 
-    shape = [1,1,1,1]
-    shape[shard_dim] = n
-    sharding = PositionalSharding(devices).reshape(shape)
+    hlo = with_sharding(PositionalSharding(devices).reshape(n,1,1,1))
+    assert 'all-gather' not in hlo
+    assert 'dynamic-slice' not in hlo
 
-    q,k,v = jax.device_put((q,k,v), sharding)
-    hlo = flash.lower((q,k,v)).compile().as_text()
+    hlo = with_sharding(PositionalSharding(devices).reshape(1,1,n,1))
     assert 'all-gather' not in hlo
     assert 'dynamic-slice' not in hlo
 
+    if not local:
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            hlo = with_sharding(sharding)
+            # No resharding should occur, only manual collective-permute.
+            assert 'all-gather' not in hlo
+            assert 'dynamic-slice' not in hlo
+            assert 'collective-permute' in hlo
+            # Should always run concurrently, meaning custom-call is always between start and done.
+            import re
+            collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
+            assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+
 @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
 @pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
 @pytest.mark.parametrize("local", ['local',''])
@@ -181,8 +196,7 @@ def check_sharding(sharding,q,k,v):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4, 8])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,2])
-def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype, shard_dim):
+def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype):
     window_size = (3,3) if local else (-1,-1)
 
     devices = jax.local_devices()
@@ -200,23 +214,25 @@ def flash(qkv):
     k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
 
-    if q.shape[shard_dim] % n != 0:
-        pytest.skip(f"{q.shape[shard_dim]} doesn't divide into {n} so we can't shard it.")
-
     ref_out = ref((q,k,v))
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    repl_out = flash((q,k,v))
+    ref16_out = flash((q,k,v))
+
+    def check_sharding(sharding,q,k,v):
+        (q,k,v) = jax.device_put((q,k,v), sharding)
+        out = flash((q,k,v))
+        check(ref_out,ref16_out,out)
 
-    shape = [1,1,1,1]
-    shape[shard_dim] = n
-    sharding = PositionalSharding(devices).reshape(shape)
+    check_sharding(PositionalSharding(devices).reshape(n,1,1,1),q,k,v)
+    check_sharding(PositionalSharding(devices).reshape(1,1,n,1),q,k,v)
 
-    (q,k,v) = jax.device_put((q,k,v), sharding)
-    hlo = flash.lower((q,k,v)).compile().as_text()
-    out = flash((q,k,v))
-    check(ref_out, repl_out, out)
+    if not local:
+        # Ring attention
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            check_sharding(sharding,q,k,v)
 
 if __name__ == '__main__':
     test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)
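
For reference, the overlap property asserted by the new HLO test can also be spot-checked by hand on a multi-GPU host; apart from `flash_mha`, the names and shapes below are illustrative:

```py
import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.local_devices()
x = jax.random.normal(jax.random.PRNGKey(0),
                      [1, 128 * len(devices), 4, 32], dtype=jnp.float16)

with Mesh(np.array(devices), axis_names=('x',)) as mesh:
    sharding = NamedSharding(mesh, P(None, 'x', None, None))  # shard the length dim
    q = k = v = jax.device_put(x, sharding)
    hlo = jax.jit(flash_mha).lower(q, k, v).compile().as_text()

assert 'collective-permute' in hlo    # ring attention communicates via ppermute
assert 'all-gather' not in hlo        # and never reshards to a replicated layout
# Each permute start/done pair should have a kernel custom-call between them,
# i.e. communication overlaps with compute instead of serializing around it.
ops = ' '.join(re.findall(r'collective-permute-start|collective-permute-done|custom-call', hlo))
assert 'collective-permute-start collective-permute-done' not in ops
```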
