@@ -112,34 +112,31 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
     n_devices = xr.global_runtime_device_count()
     xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

-    q = torch.randn(4, 2, 128, 4).to("xla")
-    k = torch.randn(4, 2, 128, 4).to("xla")
-    v = torch.randn(4, 2, 128, 4).to("xla")
-    q_segment_ids = torch.ones(4, 128, device=q.device, dtype=torch.float32).to("xla")
-    kv_segment_ids = torch.rand(4, 128).to("xla")
+    q = torch.randn(16, 32, 2048, 64).to("xla")
+    k = torch.randn(16, 32, 128, 64).to("xla")
+    v = torch.randn(16, 32, 128, 64).to("xla")
+    q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
+    kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
+    kv_segment_ids[:8, :, 30:] = -10000.0
+    kv_segment_ids[8:, :, 60:] = -10000.0

     o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
     self.assertEqual(
         torch_xla._XLAC._get_xla_sharding_spec(o),
         f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

-    attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
-    attention_mask = attention_mask.repeat_interleave(2, dim=0)
-    attention_mask = attention_mask.view(4, 2, 128, 128)
-    # attention_mask = torch.ones(4, 2, 128, 128).to("xla")
-    # head_size = self.heads
-    # current_length: int = attention_mask.shape[-1]
-    # if current_length != target_length:
-    #   attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-
-    # if attention_mask.shape[0] < 4 * head_size:
-    #   attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
-    #
-    # attention_mask = attention_mask.view(
-    #     batch_size, attn.heads, -1, attention_mask.shape[-1]
-    # )
+    attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
+    attention_mask = attention_mask.view(16, 32, 1, 128)

     expected_o = self._attention(q, k, v, attn_mask=attention_mask)
+    # expected_o = F.scaled_dot_product_attention(
+    #     q,
+    #     k,
+    #     v,
+    #     attn_mask=attention_mask,
+    #     dropout_p=0.0,
+    #     is_causal=False,
+    # )
     diff = (expected_o - o).abs()
     # z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)
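For context on the reference side of the comparison: the new code builds the mask as an additive bias, where 0.0 keeps a key position and -10000.0 effectively removes it after softmax, and broadcasts one mask row per batch element across all 32 heads via repeat_interleave. Below is a minimal CPU-only sketch of that idea, not part of the diff; reference_attention is a hypothetical stand-in for the test's self._attention helper (whose body is not shown here), and the shapes are shrunk from the test's (16, 32, 2048, 64).

import math
import torch
import torch.nn.functional as F


def reference_attention(q, k, v, attn_mask=None):
  # Plain scaled dot-product attention with an additive mask
  # (a stand-in for the test's self._attention helper).
  scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
  if attn_mask is not None:
    scores = scores + attn_mask  # -10000.0 entries are ~zeroed out by softmax
  return F.softmax(scores, dim=-1) @ v


batch, heads, q_len, kv_len, dim = 2, 4, 8, 16, 8
q = torch.randn(batch, heads, q_len, dim)
k = torch.randn(batch, heads, kv_len, dim)
v = torch.randn(batch, heads, kv_len, dim)

# Additive bias over key positions, as in the diff: zeros keep a position,
# -10000.0 masks it. Here keys >= 10 are masked for every batch element.
kv_segment_ids = torch.zeros(batch, 1, kv_len)
kv_segment_ids[:, :, 10:] = -10000.0

# Broadcast the per-batch mask across heads, mirroring the diff's
# repeat_interleave(32, dim=0) followed by view(16, 32, 1, 128).
attention_mask = kv_segment_ids.repeat_interleave(heads, dim=0).view(
    batch, heads, 1, kv_len)

out = reference_attention(q, k, v, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 8, 8])

Shaping the mask as (batch, heads, 1, kv_len) lets it broadcast over the query dimension, so every query row in a batch element shares the same set of allowed key positions.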