from jax._src.ad_checkpoint import _optimization_barrier

- def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q, k, v):
-     [n,l,h,d] = q.shape
-
-     q_ix = jax.lax.axis_index(axis_name)
-     k_ix = jax.lax.axis_index(axis_name)
-
-     o = jnp.zeros([n,l,h,d], jnp.float32)
-     lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
-
-     # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-     def f(c, a):
-         (k, v, o, lse, k_ix) = c
-
-         o1, lse1 = o, lse
-         if is_causal:
-             o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
-                                       [
-                                           lambda q, k, v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
-                                           lambda q, k, v: _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                                           lambda q, k, v: _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
-                                       ], q, k, v)
-         else:
-             o2, lse2 = _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-         o2 = o2.astype(jnp.float32)
-
-         mx = jnp.maximum(lse1, lse2)
-         mn = jnp.minimum(lse1, lse2)
-         lse = jnp.log1p(jnp.exp(mn - mx)) + mx
-
-         o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
-              o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
-
-         k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-         v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-         k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-
-         return ((k2, v2, o, lse, k_ix), None)
-     acc = (k, v, o, lse, k_ix)
-     # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
-     # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
-     for _ in range(axis_size):
-         acc, _ = f(acc, None)
-         acc = _optimization_barrier(acc)
-     (_, _, o, lse, _) = acc
-     # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
-     return o.astype(q.dtype), lse
-
def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
    result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
@@ -147,21 +100,136 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
    o_sharding = arg_shardings[4]
    lse_sharding = arg_shardings[5]
    if isinstance(q_sharding, PositionalSharding):
-         do_sharding = q_sharding.replicate((1,3))
-         [n, l, h, d] = do_sharding.shape
-         lse_sharding = do_sharding.reshape(n,l,h).transpose(0,2,1) # n h l
-         result_shardings = (do_sharding,)*3
-         arg_shardings = (do_sharding,)*5 + (lse_sharding,)
+         assert q_sharding == k_sharding, "Expect q and k sharding to match"
+         assert q_sharding == v_sharding, "Expect q and v sharding to match"
+         [n, l, h, d] = q_sharding.shape
+         assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
+         assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
+         lse_sharding = q_sharding.reshape(n,h,1) # n h l
+         result_shardings = (q_sharding,)*3
+         arg_shardings = (q_sharding,)*5 + (lse_sharding,)
    elif isinstance(q_sharding, NamedSharding):
        mesh = q_sharding.mesh
        [n,l,h,d] = q_sharding.spec
-         do_sharding = NamedSharding(mesh, P(n,None,h,None))
-         lse_sharding = NamedSharding(mesh, P(n,h,None))
-         result_shardings = (do_sharding,)*3
+         assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
+         if l != None:
+             # assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
+             assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
+             result_shardings = q_sharding, q_sharding, q_sharding
+             lse_sharding = NamedSharding(mesh, P(n,h,l))
+             arg_shardings = (q_sharding,)*5 + (lse_sharding,)
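+             # The sequence axis `l` is sharded over a named mesh axis, so the backward
+             # pass is computed with ring_bwd, which rotates blocks around that axis with ppermute.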
+             axis_name = l
+             axis_size = mesh.shape[axis_name]
+             # ring attention
+             return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
+         else:
+             result_shardings = q_sharding, q_sharding, q_sharding
+             lse_sharding = NamedSharding(mesh, P(n,h,l))
+             arg_shardings = (q_sharding,)*5 + (lse_sharding,)
    def fwd(*args):
        return _flash_mha_bwd_hlo(*args, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
    return mesh, fwd, result_shardings, arg_shardings

_flash_mha_bwd_hlo_sharded.def_partition(
    infer_sharding_from_operands=infer_sharding_bwd,
    partition=partition_bwd)
+
+ # ==== Ring Forward ====
+
+ def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q, k, v):
+     [n,l,h,d] = q.shape
+
+     q_ix = jax.lax.axis_index(axis_name)
+     k_ix = jax.lax.axis_index(axis_name)
+
+     o = jnp.zeros([n,l,h,d], jnp.float32)
+     lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+     # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+     def f(c, a):
+         (k, v, o, lse, k_ix) = c
+
+         o1, lse1 = o, lse
+         if is_causal:
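+             # Compare this step's K/V block index to our Q block index:
+             #   k_ix > q_ix  -> block is entirely in the future: fully masked (zeros, lse = -inf)
+             #   k_ix == q_ix -> diagonal block: apply the causal mask inside the kernel
+             #   k_ix < q_ix  -> block is entirely in the past: no mask needed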
+             o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
+                                       [
+                                           lambda q, k, v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                                           lambda q, k, v: _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                                           lambda q, k, v: _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                                       ], q, k, v)
+         else:
+             o2, lse2 = _flash_mha_fwd_hlo(q, k, v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+         o2 = o2.astype(jnp.float32)
+
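+         # Merge the running result with this block's result in log-space:
+         # lse = log(exp(lse1) + exp(lse2)), computed stably via max/log1p, then
+         # o = o1*exp(lse1-lse) + o2*exp(lse2-lse) rescales both partial outputs.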
+         mx = jnp.maximum(lse1, lse2)
+         mn = jnp.minimum(lse1, lse2)
+         lse = jnp.log1p(jnp.exp(mn - mx)) + mx
+
+         o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+              o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
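+         # Rotate the K/V blocks (and the index of the block we currently hold) one device
+         # forward around the ring; after axis_size steps every device has seen every block.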
+         k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
+         v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
+         k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
+
+         return ((k2, v2, o, lse, k_ix), None)
+     acc = (k, v, o, lse, k_ix)
+     # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
+     # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
+     for _ in range(axis_size):
+         acc, _ = f(acc, None)
+         acc = _optimization_barrier(acc)
+     (_, _, o, lse, _) = acc
+     # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
+     return o.astype(q.dtype), lse
+
+ # ==== Ring Backward ====
+
+ # This doesn't seem like the most efficient way to do this, kind of wasting compute by calculating every dq,dk,dv twice.
+ # Should we send the accumulator for dk,dv cross-device instead? Relying on the fact that after a full cycle, they return to the starting device.
+ def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do, q, k, v, o, lse):
+     [n,l,h,d] = q.shape
+
+     ix = jax.lax.axis_index(axis_name)
+
+     dq = jnp.zeros([n,l,h,d], jnp.float32)
+     dk = jnp.zeros([n,l,h,d], jnp.float32)
+     dv = jnp.zeros([n,l,h,d], jnp.float32)
+
+     # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+     def f(acc, a):
+         (do2, q2, k2, v2, o2, lse2, ix2, dq, dk, dv) = acc
+
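+         # Each step pairs the stationary tensors on this device with the visiting block:
+         # dq for the local q against the visiting k2/v2, and dk/dv for the local k/v
+         # against the visiting do2/q2/o2/lse2. All three accumulators stay in fp32.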
+         cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
+         # 0: ix < ix2
+         # 1: ix = ix2
+         # 2: ix > ix2
+         if is_causal:
+             dqa = jax.lax.switch(cmp, [
+                 lambda q, k, v: jnp.zeros([n,l,h,d], q.dtype),
+                 lambda q, k, v: _flash_mha_bwd_hlo(do, q, k2, v2, o, lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
+                 lambda q, k, v: _flash_mha_bwd_hlo(do, q, k2, v2, o, lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
+             ], q, k, v)
+             dka, dva = jax.lax.switch(cmp, [
+                 lambda q, k, v: _flash_mha_bwd_hlo(do2, q2, k, v, o2, lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
+                 lambda q, k, v: _flash_mha_bwd_hlo(do2, q2, k, v, o2, lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
+                 lambda q, k, v: (jnp.zeros([n,l,h,d], q.dtype), jnp.zeros([n,l,h,d], q.dtype)),
+             ], q, k, v)
+         else:
+             dqa, _, _ = _flash_mha_bwd_hlo(do, q, k2, v2, o, lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+             _, dka, dva = _flash_mha_bwd_hlo(do2, q2, k, v, o2, lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+
+         dq += dqa
+         dk += dka
+         dv += dva
+
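+         # Pass the visiting block (activations, output gradient, and its index) on to the
+         # next device in the ring; only the dq/dk/dv accumulators stay put.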
+         (do2, q2, k2, v2, o2, lse2, ix2) = jax.lax.ppermute((do2, q2, k2, v2, o2, lse2, ix2), axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
+
+         return ((do2, q2, k2, v2, o2, lse2, ix2, dq, dk, dv), None)
+     acc = (do, q, k, v, o, lse, ix, dq, dk, dv)
+     # Unrolled as above.
+     for _ in range(axis_size):
+         acc, _ = f(acc, None)
+         acc = _optimization_barrier(acc)
+     (do2, q2, k2, v2, o2, lse2, ix2, dq, dk, dv) = acc
+     return dq.astype(q.dtype), dk.astype(q.dtype), dv.astype(q.dtype)
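+
+ # Usage sketch (illustrative only, not part of this change): ring attention is selected
+ # when q/k/v are sharded along the sequence axis `l` with a NamedSharding under jit.
+ # The entry point name `flash_mha` and the mesh axis name 'seq' below are assumptions.
+ #
+ #   from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+ #   with Mesh(jax.devices(), ('seq',)) as mesh:
+ #       spec = NamedSharding(mesh, P(None, 'seq', None, None))  # [n, l, h, d]: shard l
+ #       q, k, v = (jax.device_put(t, spec) for t in (q, k, v))
+ #       o = jax.jit(flash_mha)(q, k, v)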