
Commit 92f9691

Implement custom sharding for flash_mha, to allow efficient multi-GPU computation when the inputs are sharded across the batch or head dimensions.
1 parent 5f8ff02 commit 92f9691
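
For context, a minimal usage sketch (not part of the diff): place the device axis on the head dimension of q/k/v and call flash_mha under jit. The shapes, device count, and the choice of sharding the head axis are illustrative assumptions; the sketch mirrors lame_test.py below.

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding
from flash_attn_jax import flash_mha

devices = jax.devices()                       # e.g. 2 GPUs
# q/k/v layout is (n, l, h, d); put the device axis on h (heads must divide evenly).
sharding = PositionalSharding(devices).reshape(1, 1, -1, 1)

q = jax.random.normal(jax.random.PRNGKey(0), [2, 4096, 8, 32]).astype(jnp.float16)
k = jax.random.normal(jax.random.PRNGKey(1), [2, 4096, 8, 32]).astype(jnp.float16)
v = jax.random.normal(jax.random.PRNGKey(2), [2, 4096, 8, 32]).astype(jnp.float16)
q, k, v = (jax.device_put(x, sharding) for x in (q, k, v))

o = jax.jit(flash_mha)(q, k, v)               # each device runs flash attention on its own heads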

3 files changed: +414 −73 lines

Diff for: lame_test.py (+25 −58)
@@ -19,7 +19,6 @@
 # from flash_attn_jax.flash import flash_mha_fwd, flash_mha_bwd
 from flash_attn_jax import flash_mha

-
 if __name__ == '__main__':
     import time
     import numpy as np
@@ -51,66 +50,34 @@ def pretty(tensor):
     def fwd(q,k,v):
         return flash_mha(q,k,v)

+    # print(fwd.lower(q,k,v).as_text())
+
     from jax.sharding import PositionalSharding
     from einops import rearrange

-    sharding = PositionalSharding(jax.devices())
+    # sharding = PositionalSharding(jax.devices())
+    devices = jax.devices()
+    # devices = [*jax.devices(), *jax.devices(backend='cpu')]
+    n_device = len(devices)
+    sharding = PositionalSharding(devices).reshape(1,-1,1,1)#.replicate()
+
+
+    # from jax.experimental import mesh_utils
+    # from jax.sharding import PartitionSpec as P, Mesh
+    # from jax.sharding import NamedSharding
+    # devices = np.array(jax.devices()) #mesh_utils.create_device_mesh((1,))
+    # mesh = Mesh(devices, axis_names=('x',))
+    # sharding = NamedSharding(mesh, P(None,None,'x',None))
+
+    # print(mesh)
+
+    o_ref = fwd(q,k,v)

-    q = jax.device_put(q, sharding.reshape(2,1,1,1))
-    k = jax.device_put(k, sharding.reshape(2,1,1,1))
-    v = jax.device_put(v, sharding.reshape(2,1,1,1))
+    q = jax.device_put(q, sharding)
+    k = jax.device_put(k, sharding)
+    v = jax.device_put(v, sharding)
     jax.debug.visualize_array_sharding(rearrange(q, 'n l h d -> n (l h d)'))
     print(fwd.lower(q,k,v).compile().as_text())
-    exit()
-
-    # print('==== forward ====')
-    # q = jax.random.normal(jax.random.PRNGKey(0), [32, 4096, 4, 32]).astype(jnp.float16)
-    # k = jax.random.normal(jax.random.PRNGKey(1), [32, 4096, 4, 32]).astype(jnp.float16)
-    # v = jax.random.normal(jax.random.PRNGKey(2), [32, 4096, 4, 32]).astype(jnp.float16)
-
-    # @jax.jit
-    # def fwd(q,k,v):
-    #     o = flash_mha(q,k,v)
-    #     for _ in range(32):
-    #         o = flash_mha(q,k,o)
-    #     return o
-
-    # @jax.jit
-    # def fwd_jax(q,k,v):
-    #     ro = pure_mha(q,k,v)
-    #     for _ in range(32):
-    #         ro = pure_mha(q,k,ro)
-    #     return ro
-
-    # o = fwd(q,k,v) #, softmax_scale=float(np.sqrt(1/32)))[0]
-    # start = time.time()
-    # o = fwd(q,k,v) #, softmax_scale=float(np.sqrt(1/32)))[0]
-    # print('flash:', time.time() - start, 'seconds')
-    # ro = fwd_jax(q,k,v)
-    # start = time.time()
-    # ro = fwd_jax(q,k,v)
-    # print('jax:', time.time() - start, 'seconds')
-    # print(pretty(jnp.abs(o - ro)), jnp.mean(jnp.abs(ro)))
-
-    # @jax.jit
-    # @jax.grad
-    # def grad_pure(inputs):
-    #     q,k,v = inputs
-    #     return pure_mha(q,k,v).sum()
-
-    # @jax.jit
-    # @jax.grad
-    # def grad_flash(inputs):
-    #     q,k,v = inputs
-    #     return flash_mha(q,k,v).sum()
-
-    # print('==== backward ====')
-    # q = jax.random.normal(jax.random.PRNGKey(0), [1, 4, 2, 32]).astype(jnp.float16)
-    # k = jax.random.normal(jax.random.PRNGKey(1), [1, 4, 2, 32]).astype(jnp.float16)
-    # v = jax.random.normal(jax.random.PRNGKey(2), [1, 4, 2, 32]).astype(jnp.float16)
-    # dq, dk, dv = grad_flash((q,k,v))
-    # rdq, rdk, rdv = grad_pure((q,k,v))
-    # # print(rdq, jnp.mean(jnp.abs(rdq)))
-    # print('q', pretty(jnp.abs(dq - rdq)), jnp.mean(jnp.abs(rdq)))
-    # print('k', pretty(jnp.abs(dk - rdk)), jnp.mean(jnp.abs(rdk)))
-    # print('v', pretty(jnp.abs(dv - rdv)), jnp.mean(jnp.abs(rdv)))
+    o = fwd(q,k,v)
+    jax.debug.visualize_array_sharding(rearrange(o, 'n l h d -> n (l h d)'))
+    print((o - o_ref).std())
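
The commented-out NamedSharding path in the test corresponds to roughly the following fragment, which would replace the PositionalSharding lines above (q, k, v as created earlier in the test; the 1-D mesh and axis name 'x' come from those comments, sharding the head axis here is an assumption):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
sharding = NamedSharding(mesh, P(None, None, 'x', None))   # shard the head axis
q = jax.device_put(q, sharding)
k = jax.device_put(k, sharding)
v = jax.device_put(v, sharding)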

Diff for: src/flash_attn_jax/flash.py (+177 −15)
@@ -1,4 +1,4 @@
-from functools import partial
+from functools import partial, wraps

 import jax
 import jax.numpy as jnp
@@ -10,13 +10,17 @@
 from jax.interpreters.mlir import ir
 from jax.lib import xla_client
 from jaxlib.hlo_helpers import custom_call
+from jax.experimental.custom_partitioning import custom_partitioning
+
 from einops import rearrange
 import math

 import flash_attn_jax.flash_api as flash_api

 # ==== Register primitives ====

+# We do this with two sets of primitives.
+# These are the main ones used, and support any settings or sharding.
 _flash_mha_fwd_p = core.Primitive("flash_mha_fwd")
 _flash_mha_fwd_p.multiple_results = True
 _flash_mha_fwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_p))
@@ -25,8 +29,29 @@
 _flash_mha_bwd_p.multiple_results = True
 _flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))

-
-def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
+# The low level 'cuda' primitives are only used for lowering to hlo,
+# and require d to be padded to a multiple of 8, which we add during
+# lowering of the main prims above.
+_flash_mha_fwd_cuda_p = core.Primitive("flash_mha_fwd_cuda")
+_flash_mha_fwd_cuda_p.multiple_results = True
+_flash_mha_fwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_cuda_p))
+
+_flash_mha_bwd_cuda_p = core.Primitive("flash_mha_bwd_cuda")
+_flash_mha_bwd_cuda_p.multiple_results = True
+_flash_mha_bwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_cuda_p))
+
+# @partial(partial, partial)
+# def trace(name, func):
+#     @wraps(func)
+#     def f(*args, **kwargs):
+#         print(name, args, kwargs)
+#         return func(*args, **kwargs)
+#     return f
+
+# ==== Single shard low level frontend ====
+# Adds padding before calling into the cuda primitive.
+
+def _flash_mha_fwd_cuda_1(q, k, v, softmax_scale, is_causal, window_size):
     d = q.shape[-1]
     assert len(q.shape) == 4
     assert d == k.shape[-1]
@@ -37,12 +62,12 @@ def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
     q = jnp.pad(q, padding)
     k = jnp.pad(k, padding)
     v = jnp.pad(v, padding)
-    out, lse = _flash_mha_fwd_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
+    out, lse = _flash_mha_fwd_cuda_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
     if d % 8 != 0:
         out = out[..., :d]
     return out, lse

-def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
+def _flash_mha_bwd_cuda_1(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
     d = q.shape[-1]
     assert len(q.shape) == 4
     assert d == k.shape[-1]
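
The padding contract mentioned above, shown in isolation: a sketch of the pad-then-slice pattern the *_cuda_1 wrappers apply around the cuda primitives. The helper names are hypothetical; the real wrappers compute the padding once and apply it to k, v, out and dout as well.

import jax.numpy as jnp

def _pad_d_to_8(x):
    # Pad only the last (head) dimension up to the next multiple of 8.
    d = x.shape[-1]
    pad = (8 - d % 8) % 8
    return jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)]), d

def _unpad_d(x, d_og):
    # Slice the head dimension back to its original size.
    return x[..., :d_og] if d_og % 8 != 0 else x
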
@@ -55,11 +80,129 @@ def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size
     v = jnp.pad(v, padding)
     out = jnp.pad(out, padding)
     dout = jnp.pad(dout, padding)
-    dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
+    dq, dk, dv = _flash_mha_bwd_cuda_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
     if d % 8 != 0:
         return dq[...,:d], dk[...,:d], dv[...,:d]
     return dq, dk, dv

+# ==== Sharding ====
+
+_flash_mha_fwd_cuda = custom_partitioning(_flash_mha_fwd_cuda_1, static_argnums=(3,4,5))
+_flash_mha_bwd_cuda = custom_partitioning(_flash_mha_bwd_cuda_1, static_argnums=(6,7,8))
+
+from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh
+from jax.sharding import NamedSharding
+from jax.sharding import PositionalSharding
+
+# @trace("partition_fwd")
+def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+    result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
+    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+
+    q_sharding = arg_shardings[0]
+    if isinstance(q_sharding, PositionalSharding):
+        if not is_causal and window_size == (-1,-1):
+            # We can handle Q that's sharded across the L dimension
+            # without replicating Q by executing it as a cross
+            # attention:
+            #
+            # q : n [L/devices] h d
+            # kv : n L h d
+            # -> o : n [L/devices] h d
+            #
+            # TODO: We could handle q sharded across L even with
+            # causal/local if we could communicate the slice offset
+            # (of q in kv) to the c++ driver. But it's unclear how to
+            # do that since the HLO has to be identical (SPMD).
+            q_sharding = q_sharding.replicate(3)
+            kv_sharding = q_sharding.replicate(1)
+            (n,l,h,d) = q_sharding.shape
+            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1) # n h l
+            arg_shardings = q_sharding, kv_sharding, kv_sharding
+        else:
+            # We need to replicate d always.
+            q_sharding = q_sharding.replicate((1,3))
+            (n,l,h,d) = q_sharding.shape # l=1, d=1
+            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1)
+            arg_shardings = q_sharding, q_sharding, q_sharding
+    elif isinstance(q_sharding, NamedSharding):
+        mesh = q_sharding.mesh
+        [n,l,h,d] = q_sharding.spec
+        if not is_causal and window_size == (-1,-1):
+            q_sharding = NamedSharding(mesh, P(n,l,h,None))
+            kv_sharding = NamedSharding(mesh, P(n,None,h,None))
+            lse_sharding = NamedSharding(mesh, P(n,h,l))
+        else:
+            q_sharding = NamedSharding(mesh, P(n,None,h,None))
+            kv_sharding = q_sharding
+            lse_sharding = NamedSharding(mesh, P(n,h,None))
+        result_sharding = (q_sharding, lse_sharding)
+        arg_shardings = (q_sharding, kv_sharding, kv_sharding)
+    def fwd(q,k,v):
+        return _flash_mha_fwd_cuda_1(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
+    return mesh, fwd, result_shardings, arg_shardings
+
+# @trace("infer_sharding_fwd")
+def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+    q_sharding = arg_shardings[0]
+    if isinstance(q_sharding, PositionalSharding):
+        [n,l,h,d] = q_sharding.shape
+        # breakpoint()
+        result_sharding = (q_sharding, # [n,l,h,d]
+                           q_sharding.replicate(3).reshape(n,l,h).transpose((0,2,1)) # [n,h,l]
+                           )
+    elif isinstance(q_sharding, NamedSharding):
+        [n,l,h,d] = q_sharding.spec
+        result_sharding = (q_sharding,
+                           NamedSharding(q_sharding.mesh, P(n,h,l)))
+    return result_sharding
+
+_flash_mha_fwd_cuda.def_partition(
+    infer_sharding_from_operands=infer_sharding_fwd,
+    partition=partition_fwd)
+
+def infer_sharding_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+    # args: dout, q, k, v, out, lse
+    # outs: dq, dk, dv
+    # i think generally we want the output sharding for dq,dk,dv to be the same as q,k,v?
+    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+    q_sharding = arg_shardings[1]
+    k_sharding = arg_shardings[2]
+    v_sharding = arg_shardings[3]
+    return q_sharding, k_sharding, v_sharding
+
+def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+    result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
+    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+
+    do_sharding = arg_shardings[0]
+    q_sharding = arg_shardings[1]
+    k_sharding = arg_shardings[2]
+    v_sharding = arg_shardings[3]
+    o_sharding = arg_shardings[4]
+    lse_sharding = arg_shardings[5]
+    if isinstance(q_sharding, PositionalSharding):
+        do_sharding = q_sharding.replicate((1,3))
+        [n, l, h, d] = do_sharding.shape
+        lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
+        result_shardings = (do_sharding,)*3
+        arg_shardings = (do_sharding,)*5 + (lse_sharding,)
+    elif isinstance(q_sharding, NamedSharding):
+        mesh = q_sharding.mesh
+        [n,l,h,d] = q_sharding.spec
+        do_sharding = NamedSharding(mesh, P(n,None,h,None))
+        lse_sharding = NamedSharding(mesh, P(n,h,None))
+        result_shardings = (do_sharding,)*3
+    def fwd(*args):
+        return _flash_mha_bwd_cuda_1(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
+    return mesh, fwd, result_shardings, arg_shardings
+
+_flash_mha_bwd_cuda.def_partition(
+    infer_sharding_from_operands=infer_sharding_bwd,
+    partition=partition_bwd)
+
 # ==== CUDA lowerings ====

 # Register functions defined in gpu_ops as custom call target for GPUs
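
For readers unfamiliar with jax.experimental.custom_partitioning, here is a toy example, independent of this repo, using the same def_partition contract as the sharding code above (exact callback signatures vary a little between JAX versions):

import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def double(x):
    return 2.0 * x

def infer_sharding(mesh, arg_shapes, result_shape):
    # The result inherits the sharding of the single operand.
    return arg_shapes[0].sharding

def partition(mesh, arg_shapes, result_shape):
    arg_shardings = (arg_shapes[0].sharding,)
    result_shardings = arg_shapes[0].sharding
    def per_shard(x):
        # Runs on each device's shard; an elementwise op needs no
        # communication, so no resharding is requested.
        return 2.0 * x
    return mesh, per_shard, result_shardings, arg_shardings

double.def_partition(infer_sharding_from_operands=infer_sharding,
                     partition=partition)
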
@@ -72,7 +215,6 @@ def row_major(shape):
     return [row_major(shape) for shape in shapes]

 def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is_causal=False, window_size=None):
-    # print(type(q), dir(q), q.type)
     q_type = ir.RankedTensorType(q.type)
     q_shape = q_type.shape
     k_type = ir.RankedTensorType(k.type)
@@ -115,7 +257,7 @@ def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is
     return out

 mlir.register_lowering(
-    _flash_mha_fwd_p,
+    _flash_mha_fwd_cuda_p,
     _flash_mha_fwd_cuda_lowering,
     platform="gpu",
 )
@@ -176,11 +318,25 @@ def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=Non
     return out

 mlir.register_lowering(
-    _flash_mha_bwd_p,
+    _flash_mha_bwd_cuda_p,
     _flash_mha_bwd_cuda_lowering,
     platform="gpu",
 )

+# ==== High level ops ====
+
+mlir.register_lowering(
+    _flash_mha_fwd_p,
+    mlir.lower_fun(_flash_mha_fwd_cuda),
+    platform="gpu",
+)
+
+mlir.register_lowering(
+    _flash_mha_bwd_p,
+    mlir.lower_fun(_flash_mha_bwd_cuda),
+    platform="gpu",
+)
+
 # ==== Abstract evaluation rules ====

 def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
@@ -194,6 +350,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=No
         ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
         ShapedArray([n, h, l], jnp.float32)
     )
+_flash_mha_fwd_cuda_p.def_abstract_eval(_flash_mha_fwd_abstract)
 _flash_mha_fwd_p.def_abstract_eval(_flash_mha_fwd_abstract)


@@ -212,6 +369,7 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, d_og=No
         ShapedArray(k.shape, k_dtype, named_shape=k.named_shape),
         ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
     )
+_flash_mha_bwd_cuda_p.def_abstract_eval(_flash_mha_bwd_abstract)
 _flash_mha_bwd_p.def_abstract_eval(_flash_mha_bwd_abstract)

 # ==== VMap rules ====
@@ -250,19 +408,19 @@ def custom_vjp(cls, nondiff_argnums=()):
     f.defvjp(cls.fwd, cls.bwd)
     return f

-# Apparently we need nondiff_argnums so that softmax_scale doesn't get
-# turned into a Tracer, which we can't use as a static parameter. It
-# gets placed at the front of the argument list in bwd.
+# Apparently we need nondiff_argnums so that config doesn't get turned
+# into Tensors. They get placed at the front of the argument list in
+# bwd.
 @partial(custom_vjp, nondiff_argnums=(3,))
 class _flash_mha_vjp:
     def base(q,k,v,config):
-        return flash_mha_fwd(q,k,v, **config)[0]
+        return _flash_mha_fwd_p.bind(q,k,v, **config)[0]
     def fwd(q,k,v,config):
-        out, lse = flash_mha_fwd(q,k,v, **config)
+        out, lse = _flash_mha_fwd_p.bind(q,k,v, **config)
         return out, (q,k,v,out,lse)
     def bwd(config, pack, dout):
         (q,k,v,out,lse) = pack
-        dq, dk, dv = flash_mha_bwd(dout, q, k, v, out, lse, **config)
+        dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, **config)
         return (dq,dk,dv)

 # ==== Frontend ====
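
A toy jax.custom_vjp with nondiff_argnums, independent of this repo, illustrating the comment above: the non-differentiable argument is passed first to bwd.

import jax
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale(x, factor):
    return factor * x

def scale_fwd(x, factor):
    return factor * x, None           # no residuals needed

def scale_bwd(factor, residuals, g):  # nondiff arg arrives first
    return (factor * g,)

scale.defvjp(scale_fwd, scale_bwd)

print(jax.grad(scale)(3.0, 2.0))      # 2.0
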
@@ -274,6 +432,10 @@ def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
     provided (ie. can't be a tensor or a tracer).0

     """
+    assert len(q.shape) == 4
+    assert len(k.shape) == 4
+    assert len(v.shape) == 4
+
     if softmax_scale is None:
         softmax_scale = 1/math.sqrt(q.shape[-1])
     assert type(softmax_scale) is float