- from functools import partial
+ from functools import partial, wraps

import jax
import jax.numpy as jnp
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
+ from jax.experimental.custom_partitioning import custom_partitioning
+
from einops import rearrange
import math

import flash_attn_jax.flash_api as flash_api

# ==== Register primitives ====

+ # We do this with two sets of primitives.
+ # These are the main ones used, and support any settings or sharding.
_flash_mha_fwd_p = core.Primitive("flash_mha_fwd")
_flash_mha_fwd_p.multiple_results = True
_flash_mha_fwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_p))
_flash_mha_bwd_p.multiple_results = True
_flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))

-
- def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
+ # The low level 'cuda' primitives are only used for lowering to hlo,
+ # and require d to be padded to a multiple of 8, which we add during
+ # lowering of the main prims above.
+ _flash_mha_fwd_cuda_p = core.Primitive("flash_mha_fwd_cuda")
+ _flash_mha_fwd_cuda_p.multiple_results = True
+ _flash_mha_fwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_fwd_cuda_p))
+
+ _flash_mha_bwd_cuda_p = core.Primitive("flash_mha_bwd_cuda")
+ _flash_mha_bwd_cuda_p.multiple_results = True
+ _flash_mha_bwd_cuda_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_cuda_p))
+
+ # @partial(partial, partial)
+ # def trace(name, func):
+ #     @wraps(func)
+ #     def f(*args, **kwargs):
+ #         print(name, args, kwargs)
+ #         return func(*args, **kwargs)
+ #     return f
+
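# A minimal sketch of the padding contract described above (the helper name and
# pad computation here are illustrative, not part of this module's API): head
# dims are rounded up to a multiple of 8 before reaching the cuda primitives,
# and results are sliced back to the original d afterwards.
def _example_pad_head_dim(x):
    d = x.shape[-1]
    pad = (8 - d % 8) % 8  # e.g. d=61 -> pad=3, giving a padded head dim of 64
    x_padded = jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])
    return x_padded, d  # callers slice outputs back with out[..., :d]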
+ # ==== Single shard low level frontend ====
+ # Adds padding before calling into the cuda primitive.
+
+ def _flash_mha_fwd_cuda_1(q, k, v, softmax_scale, is_causal, window_size):
    d = q.shape[-1]
    assert len(q.shape) == 4
    assert d == k.shape[-1]
@@ -37,12 +62,12 @@ def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
    q = jnp.pad(q, padding)
    k = jnp.pad(k, padding)
    v = jnp.pad(v, padding)
-     out, lse = _flash_mha_fwd_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
+     out, lse = _flash_mha_fwd_cuda_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
    if d % 8 != 0:
        out = out[..., :d]
    return out, lse

- def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
+ def _flash_mha_bwd_cuda_1(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
    d = q.shape[-1]
    assert len(q.shape) == 4
    assert d == k.shape[-1]
@@ -55,11 +80,129 @@ def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size
    v = jnp.pad(v, padding)
    out = jnp.pad(out, padding)
    dout = jnp.pad(dout, padding)
-     dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
+     dq, dk, dv = _flash_mha_bwd_cuda_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
    if d % 8 != 0:
        return dq[..., :d], dk[..., :d], dv[..., :d]
    return dq, dk, dv

+ # ==== Sharding ====
+
+ _flash_mha_fwd_cuda = custom_partitioning(_flash_mha_fwd_cuda_1, static_argnums=(3, 4, 5))
+ _flash_mha_bwd_cuda = custom_partitioning(_flash_mha_bwd_cuda_1, static_argnums=(6, 7, 8))
+
+ from jax.sharding import PartitionSpec as P
+ from jax.sharding import Mesh
+ from jax.sharding import NamedSharding
+ from jax.sharding import PositionalSharding
+
+ # @trace("partition_fwd")
+ def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
+     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+
+     q_sharding = arg_shardings[0]
+     if isinstance(q_sharding, PositionalSharding):
+         if not is_causal and window_size == (-1, -1):
+             # We can handle Q that's sharded across the L dimension
+             # without replicating Q by executing it as a cross
+             # attention:
+             #
+             #   q  : n [L/devices] h d
+             #   kv : n  L          h d
+             #   -> o : n [L/devices] h d
+             #
+             # TODO: We could handle q sharded across L even with
+             # causal/local if we could communicate the slice offset
+             # (of q in kv) to the c++ driver. But it's unclear how to
+             # do that since the HLO has to be identical (SPMD).
+             q_sharding = q_sharding.replicate(3)
+             kv_sharding = q_sharding.replicate(1)
+             (n, l, h, d) = q_sharding.shape
+             result_shardings = q_sharding, q_sharding.reshape((n, l, h)).transpose(0, 2, 1)  # n h l
+             arg_shardings = q_sharding, kv_sharding, kv_sharding
+         else:
+             # We need to replicate d always.
+             q_sharding = q_sharding.replicate((1, 3))
+             (n, l, h, d) = q_sharding.shape  # l=1, d=1
+             result_shardings = q_sharding, q_sharding.reshape((n, l, h)).transpose(0, 2, 1)
+             arg_shardings = q_sharding, q_sharding, q_sharding
+     elif isinstance(q_sharding, NamedSharding):
+         mesh = q_sharding.mesh
+         [n, l, h, d] = q_sharding.spec
+         if not is_causal and window_size == (-1, -1):
+             q_sharding = NamedSharding(mesh, P(n, l, h, None))
+             kv_sharding = NamedSharding(mesh, P(n, None, h, None))
+             lse_sharding = NamedSharding(mesh, P(n, h, l))
+         else:
+             q_sharding = NamedSharding(mesh, P(n, None, h, None))
+             kv_sharding = q_sharding
+             lse_sharding = NamedSharding(mesh, P(n, h, None))
+         result_shardings = (q_sharding, lse_sharding)
+         arg_shardings = (q_sharding, kv_sharding, kv_sharding)
+     def fwd(q, k, v):
+         return _flash_mha_fwd_cuda_1(q, k, v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
+     return mesh, fwd, result_shardings, arg_shardings
+
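# To make the layout above concrete, here is an illustrative set of shardings
# for the cross-attention trick, assuming a hypothetical 1-d mesh axis named
# 'seq'; it mirrors the NamedSharding branch of partition_fwd and is not used
# elsewhere in this module.
def _example_cross_attention_shardings(mesh):
    q_sharding = NamedSharding(mesh, P(None, 'seq', None, None))  # q  : n [L/devices] h d
    kv_sharding = NamedSharding(mesh, P(None, None, None, None))  # kv : n  L          h d
    lse_sharding = NamedSharding(mesh, P(None, None, 'seq'))      # lse: n  h  [L/devices]
    return q_sharding, kv_sharding, lse_sharding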
+ # @trace("infer_sharding_fwd")
+ def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+     q_sharding = arg_shardings[0]
+     if isinstance(q_sharding, PositionalSharding):
+         [n, l, h, d] = q_sharding.shape
+         # breakpoint()
+         result_sharding = (q_sharding,  # [n,l,h,d]
+                            q_sharding.replicate(3).reshape(n, l, h).transpose((0, 2, 1))  # [n,h,l]
+                            )
+     elif isinstance(q_sharding, NamedSharding):
+         [n, l, h, d] = q_sharding.spec
+         result_sharding = (q_sharding,
+                            NamedSharding(q_sharding.mesh, P(n, h, l)))
+     return result_sharding
+
+ _flash_mha_fwd_cuda.def_partition(
+     infer_sharding_from_operands=infer_sharding_fwd,
+     partition=partition_fwd)
+
+ def infer_sharding_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+     # args: dout, q, k, v, out, lse
+     # outs: dq, dk, dv
+     # Generally we want the output shardings for dq, dk, dv to match those of q, k, v.
+     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+     q_sharding = arg_shardings[1]
+     k_sharding = arg_shardings[2]
+     v_sharding = arg_shardings[3]
+     return q_sharding, k_sharding, v_sharding
+
+ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
+     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
+     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
+
+     do_sharding = arg_shardings[0]
+     q_sharding = arg_shardings[1]
+     k_sharding = arg_shardings[2]
+     v_sharding = arg_shardings[3]
+     o_sharding = arg_shardings[4]
+     lse_sharding = arg_shardings[5]
+     if isinstance(q_sharding, PositionalSharding):
+         do_sharding = q_sharding.replicate((1, 3))
+         [n, l, h, d] = do_sharding.shape
+         lse_sharding = do_sharding.reshape(n, l, h).transpose(0, 2, 1)  # n h l
+         result_shardings = (do_sharding,) * 3
+         arg_shardings = (do_sharding,) * 5 + (lse_sharding,)
+     elif isinstance(q_sharding, NamedSharding):
+         mesh = q_sharding.mesh
+         [n, l, h, d] = q_sharding.spec
+         do_sharding = NamedSharding(mesh, P(n, None, h, None))
+         lse_sharding = NamedSharding(mesh, P(n, h, None))
+         result_shardings = (do_sharding,) * 3
+     def fwd(*args):
+         return _flash_mha_bwd_cuda_1(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
+     return mesh, fwd, result_shardings, arg_shardings
+
+ _flash_mha_bwd_cuda.def_partition(
+     infer_sharding_from_operands=infer_sharding_bwd,
+     partition=partition_bwd)
+
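# Usage sketch (assumes several local GPUs and a hypothetical mesh axis named
# 'seq'; nothing below is called elsewhere in this module): with the partition
# rules registered above, a jitted call can keep q sharded along the sequence
# axis instead of gathering everything onto one device.
def _example_sharded_call(q, k, v):
    import numpy as np  # local import, only needed for this sketch
    mesh = Mesh(np.array(jax.devices()), ('seq',))
    q_sharding = NamedSharding(mesh, P(None, 'seq', None, None))  # shard q over L
    kv_sharding = NamedSharding(mesh, P(None, None, None, None))  # replicate k and v
    q = jax.device_put(q, q_sharding)
    k = jax.device_put(k, kv_sharding)
    v = jax.device_put(v, kv_sharding)
    f = jax.jit(lambda q, k, v: _flash_mha_fwd_cuda(q, k, v, 1.0, False, (-1, -1)))
    return f(q, k, v)  # (out, lse), with out still sharded along L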
# ==== CUDA lowerings ====

# Register functions defined in gpu_ops as custom call target for GPUs
@@ -72,7 +215,6 @@ def row_major(shape):
    return [row_major(shape) for shape in shapes]

def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is_causal=False, window_size=None):
-     # print(type(q), dir(q), q.type)
    q_type = ir.RankedTensorType(q.type)
    q_shape = q_type.shape
    k_type = ir.RankedTensorType(k.type)
@@ -115,7 +257,7 @@ def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is
    return out

mlir.register_lowering(
-     _flash_mha_fwd_p,
+     _flash_mha_fwd_cuda_p,
    _flash_mha_fwd_cuda_lowering,
    platform="gpu",
)
@@ -176,11 +318,25 @@ def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=Non
    return out

mlir.register_lowering(
-     _flash_mha_bwd_p,
+     _flash_mha_bwd_cuda_p,
    _flash_mha_bwd_cuda_lowering,
    platform="gpu",
)

+ # ==== High level ops ====
+
+ mlir.register_lowering(
+     _flash_mha_fwd_p,
+     mlir.lower_fun(_flash_mha_fwd_cuda),
+     platform="gpu",
+ )
+
+ mlir.register_lowering(
+     _flash_mha_bwd_p,
+     mlir.lower_fun(_flash_mha_bwd_cuda),
+     platform="gpu",
+ )
+
# ==== Abstract evaluation rules ====

def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
@@ -194,6 +350,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=No
        ShapedArray(q.shape, q_dtype, named_shape=q.named_shape),
        ShapedArray([n, h, l], jnp.float32)
    )
+ _flash_mha_fwd_cuda_p.def_abstract_eval(_flash_mha_fwd_abstract)
_flash_mha_fwd_p.def_abstract_eval(_flash_mha_fwd_abstract)

@@ -212,6 +369,7 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, d_og=No
        ShapedArray(k.shape, k_dtype, named_shape=k.named_shape),
        ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
    )
+ _flash_mha_bwd_cuda_p.def_abstract_eval(_flash_mha_bwd_abstract)
_flash_mha_bwd_p.def_abstract_eval(_flash_mha_bwd_abstract)

# ==== VMap rules ====
@@ -250,19 +408,19 @@ def custom_vjp(cls, nondiff_argnums=()):
    f.defvjp(cls.fwd, cls.bwd)
    return f

- # Apparently we need nondiff_argnums so that softmax_scale doesn't get
- # turned into a Tracer, which we can't use as a static parameter. It
- # gets placed at the front of the argument list in bwd.
+ # Apparently we need nondiff_argnums so that config doesn't get turned
+ # into tracers. They get placed at the front of the argument list in
+ # bwd.
@partial(custom_vjp, nondiff_argnums=(3,))
class _flash_mha_vjp:
    def base(q, k, v, config):
-         return flash_mha_fwd(q, k, v, **config)[0]
+         return _flash_mha_fwd_p.bind(q, k, v, **config)[0]
    def fwd(q, k, v, config):
-         out, lse = flash_mha_fwd(q, k, v, **config)
+         out, lse = _flash_mha_fwd_p.bind(q, k, v, **config)
        return out, (q, k, v, out, lse)
    def bwd(config, pack, dout):
        (q, k, v, out, lse) = pack
-         dq, dk, dv = flash_mha_bwd(dout, q, k, v, out, lse, **config)
+         dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, **config)
        return (dq, dk, dv)
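
# For reference, a self-contained sketch of how nondiff_argnums behaves with
# jax.custom_vjp (which the class-style custom_vjp helper above appears to
# build on): the static argument is passed to fwd as usual, but arrives
# *first* in bwd. Names here are illustrative only.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _example_scale(x, factor):
    return x * factor

def _example_scale_fwd(x, factor):
    return x * factor, None  # no residuals needed

def _example_scale_bwd(factor, residuals, g):
    return (g * factor,)  # gradients only for the differentiable args

_example_scale.defvjp(_example_scale_fwd, _example_scale_bwd)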

# ==== Frontend ====
@@ -274,6 +432,10 @@ def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
    provided (i.e. can't be a tensor or a tracer).

    """
+     assert len(q.shape) == 4
+     assert len(k.shape) == 4
+     assert len(v.shape) == 4
+
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(q.shape[-1])
    assert type(softmax_scale) is float
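    # Usage sketch (shapes and dtype assumed, not checked here): q, k, v are
    # [batch, seqlen, heads, head_dim] half-precision arrays, e.g.
    #   q = k = v = jnp.zeros([2, 4096, 4, 32], dtype=jnp.float16)
    #   out = flash_mha(q, k, v, is_causal=True)   # out: [2, 4096, 4, 32]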