@@ -40,7 +40,7 @@ def _attention(self, q, k, v, *, attn_mask=None, ab=None):
       attn_weight = attn_weight.masked_fill(attn_mask,
                                             torch.finfo(attn_weight.dtype).min)
     if ab is not None:
-      attn_weight = attn_weight + ab
+      attn_weight = attn_weight + ab
     attn_weight = nn.functional.softmax(attn_weight, dim=-1)
     attn_output = attn_weight @ v
     return attn_output
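Note: the `_attention` helper touched in this hunk is the plain-PyTorch reference that the Pallas flash attention outputs are checked against. Below is a minimal standalone sketch of that reference path, using the same (4, 2, 128, 8) shapes that appear later in this file; the `reference_attention` name and the random mask/bias values are illustrative assumptions, not part of the change.

import torch
import torch.nn as nn

def reference_attention(q, k, v, attn_mask=None, ab=None):
  # Raw attention scores.
  attn_weight = q @ k.transpose(-2, -1)
  if attn_mask is not None:
    # True entries are masked out with the most negative finite value.
    attn_weight = attn_weight.masked_fill(attn_mask,
                                          torch.finfo(attn_weight.dtype).min)
  if ab is not None:
    # `ab` is an additive bias applied to the attention scores.
    attn_weight = attn_weight + ab
  attn_weight = nn.functional.softmax(attn_weight, dim=-1)
  return attn_weight @ v

q = torch.randn(4, 2, 128, 8)
k = torch.randn(4, 2, 128, 8)
v = torch.randn(4, 2, 128, 8)
ab = torch.randn(4, 2, 128, 128)
mask = torch.rand(4, 2, 128, 128) > 0.5
out = reference_attention(q, k, v, attn_mask=mask, ab=ab)  # shape (4, 2, 128, 8)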
@@ -139,7 +139,7 @@ def test_flash_attention_wrapper_segment_ids_spmd(self):
         partition_spec=("data", None, None, None))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
-        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")
+        f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}")

     jax_q = jnp.array(q.numpy(), dtype=jnp.float32)
     jax_k = jnp.array(k.numpy(), dtype=jnp.float32)
@@ -175,12 +175,19 @@ def test_flash_attention_backward_segment_ids_spmd(self):
     k.retain_grad()
     v.retain_grad()

-    o = flash_attention(q, k, v, False, segment_ids, segment_ids, partition_spec=("data", None, None, None))
+    o = flash_attention(
+        q,
+        k,
+        v,
+        False,
+        segment_ids,
+        segment_ids,
+        partition_spec=("data", None, None, None))
     loss = o.sum()
     loss.backward()
     q_grad = q.grad
     k_grad = k.grad
-    v_grad = v.grad
+    v_grad = v.grad
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
@@ -192,10 +199,9 @@ def test_flash_attention_backward_segment_ids_spmd(self):
         f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(v_grad),
-        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
+        f"{{devices=[{n_devices},1,1,1]0,1,2,3}}")
     torch_xla.sync()

-
     torch.manual_seed(42)
     q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
     k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
@@ -220,6 +226,7 @@ def test_flash_attention_backward_segment_ids_spmd(self):
       self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
     jax.config.update("jax_default_matmul_precision", "default")

+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   torch.set_default_dtype(torch.float32)
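For readers skimming the diff, the SPMD flash attention backward check reduces to the pattern below. This is a condensed sketch, not the verbatim test: it assumes SPMD is enabled and a ("data",) mesh over all devices is registered as in the full file, assumes the usual torch_xla.experimental.custom_kernel import for flash_attention, and omits the segment-id arguments that the actual test passes.

import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

n_devices = xr.global_runtime_device_count()

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
# .to("xla") yields non-leaf tensors, so gradients must be retained explicitly.
q.retain_grad()
k.retain_grad()
v.retain_grad()

# Shard the batch dimension across the "data" mesh axis.
o = flash_attention(q, k, v, False, partition_spec=("data", None, None, None))
loss = o.sum()
loss.backward()

# Output and input gradients should all carry the same [n_devices,1,1,1] sharding.
# The hard-coded device ids 0,1,2,3 assume a 4-device mesh, as in the test.
expected = f"{{devices=[{n_devices},1,1,1]0,1,2,3}}"
for t in (o, q.grad, k.grad, v.grad):
  assert torch_xla._XLAC._get_xla_sharding_spec(t) == expected
torch_xla.sync()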