import os
import math
import itertools
import pytest
from einops import rearrange, repeat
import paddle
import time
from paddle.nn.functional.flash_attention import flashmask_attention
from generate_startend_row_indices import (
startend_row_indices_to_attn_bias,
generate_none_mask,
generate_sliding_window_mask,
generate_causal_document_mask,
generate_document_mask,
generate_share_question_mask,
generate_global_sliding_window_mask,
generate_causal_blockwise_mask,
generate_prefix_lm_document_mask,
generate_prefix_lm_causal_mask,
generate_qk_sparse_mask,
generate_random_eviction_mask
)
from functools import partial
from test_util import attention_ref, blockmask_to_densemask, random_blockmask, flashmask_to_densemask
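
# This test checks flashmask_attention with an additional block mask against a
# dense-mask reference (attention_ref): forward outputs and dQ/dK/dV gradients
# are compared across a range of shapes and FlashMask variants.
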
# batch_size, seqlen_q, seqlen_k, nheads, nheads_kv
shape_cases = (
[
(28, 128, 128, 16, 4),
(4, 256, 256, 4, 1),
# (2, 8192, 32768, 32, 4), # this will oom
# (2, 8192, 8192, 32, 4), # this will oom
(1, 8192, 8192, 1, 1),
# (2, 16384, 16384, 1, 1),
(1, 128, 128, 1, 1),
(1, 127, 128, 1, 1),
(1, 16384, 16384, 1, 1),
# (2, 16384, 16383, 4, 1),
# my case
]
# tridao case
+ list(itertools.product(
[1], # batch_size
[1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,], # seqlen_q
[128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,], # seqlen_k
        [1, 2], # nheads
[1], # nheads_kv
))
+ list(itertools.product(
[2], # batch_size
[4096, 4224], # seqlen_q
[4096, 4224], # seqlen_k
[6], # nheads
[6, 2, 1], # nheads_kv
))
)


# Expand each shape case with both choices of nheads_startend_row_indices (1 or nheads_kv).
def generate_shapes():
for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases:
nheads_startend_row_indices_values = [1, nheads_kv]
for nheads_startend_row_indices in nheads_startend_row_indices_values:
yield (
batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices
)


@pytest.mark.parametrize("dtype", [paddle.bfloat16])
@pytest.mark.parametrize("fa_version", [3])
@pytest.mark.parametrize("d, dv", [(128, 128)])
@pytest.mark.parametrize(
"batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices",
list(generate_shapes())
)
@pytest.mark.parametrize(
"gen_startend_row_indices",
[
# partial(generate_none_mask, causal=False), # full
# partial(generate_none_mask, causal=True), # causal
partial(generate_sliding_window_mask), # sliding window
partial(generate_causal_document_mask), # causal document mask
partial(generate_document_mask), # document mask
partial(generate_share_question_mask), # share question mask
partial(generate_global_sliding_window_mask), # global sliding window
partial(generate_causal_blockwise_mask), # causal blockwise mask
partial(generate_prefix_lm_document_mask), # prefix lm document mask
partial(generate_prefix_lm_causal_mask), # prefix lm causal mask
partial(generate_qk_sparse_mask), # qk-sparse mask
partial(generate_random_eviction_mask), # random eviction mask
],
)
def test_flashmask(
    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv,
    nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices,
    softcap=0.0,
):
paddle.seed(2024)
assert nheads % nheads_kv == 0
q_ref = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype)
# print(q_ref)
k_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype)
v_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype)
q_ref.stop_gradient = False
k_ref.stop_gradient = False
v_ref.stop_gradient = False
q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)]
q_bf16.stop_gradient = False
k_bf16.stop_gradient = False
v_bf16.stop_gradient = False
q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)]
# print(q_ref)
q.stop_gradient = False
k.stop_gradient = False
v.stop_gradient = False
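
    # Each mask generator returns FlashMask start/end row indices plus a flag
    # indicating whether the resulting mask is causal.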
startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices)
if startend_row_indices is None and causal and d == 80:
        pytest.skip("Skipping headdim 80 with flash_attn in causal mode")
# print(q_ref)
print(k_ref.shape)
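
    # Random block mask at 128x128 block granularity: shape is
    # (batch, nheads_startend_row_indices, ceil(seqlen_q / 128), ceil(seqlen_k / 128)).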
    blockmask = random_blockmask(
        shape=[
            startend_row_indices.shape[0],
            startend_row_indices.shape[1],
            (seqlen_q + 127) // 128,
            (seqlen_k + 127) // 128,
        ],
        dtype=paddle.int32,
        is_causal=causal,
        ref_q=q_ref,
    )
# print(q_ref)
# paddle.save(q, 'query.pd')
# paddle.save(k, 'key.pd')
# paddle.save(v, 'value.pd')
# paddle.save(blockmask, 'blockmask.pd')
# paddle.save(startend_row_indices, 'startend_row_indices.pd')
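
    # Build the dense reference mask: the effective mask is the elementwise AND
    # of the FlashMask-derived mask and the block mask, converted to an additive
    # 0 / -inf bias for attention_ref.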
mask_flash = flashmask_to_densemask(startend_row_indices, seqlen_q, nheads_startend_row_indices, causal)
    mask_block = blockmask_to_densemask(blockmask, seqlen_q, seqlen_k, paddle.int32, causal)
mask_inf = mask_flash & mask_block
# print(mask_inf)
attn_bias = paddle.zeros((batch_size, nheads_startend_row_indices, seqlen_q, seqlen_k), dtype=paddle.bfloat16)
attn_bias = paddle.where(mask_inf, paddle.zeros_like(attn_bias), paddle.full_like(attn_bias, float('-inf')))
paddle.save(attn_bias, 'attn_bias.pd')
# time.sleep(0.1)
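
    # Reference outputs: attention_ref in its default mode (out_ref) and a bf16
    # reference with upcast=False / reorder_ops=True (out_bf16); the bf16
    # reference is used below to bound the acceptable error of the fused kernel.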
out_ref, attn_ref = attention_ref(
q_ref,
k_ref,
v_ref,
causal=causal,
attn_bias=attn_bias
)
out_bf16, attn_bf16 = attention_ref(
q_bf16,
k_bf16,
v_bf16,
causal=causal,
attn_bias=attn_bias,
upcast=False,
reorder_ops=True
)
    # Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
assert softcap == 0.0
rtol = 2 if softcap == 0.0 else 3
print(f"Paddle naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}")
print(f"Paddle naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}")
if fa_version == 2:
paddle.set_flags({'FLAGS_flash_attn_version': 2})
elif fa_version == 3:
paddle.set_flags({'FLAGS_flash_attn_version': 3})
else:
raise ValueError(
f"Invalid flash attention version: {fa_version}"
)
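
    # Run the fused kernel with both the FlashMask row indices and the block mask.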
out, lse = flashmask_attention(
q,
k,
v,
startend_row_indices=startend_row_indices,
causal=causal,
return_softmax_lse=True,
block_mask=blockmask
)
    print(f"flashmask Output max diff at index {(out - out_ref).abs().argmax()}")
print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}")
print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}")
# if not causal:
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# breakpoint()
    # Check that FlashMask's numerical error is at most rtol times the numerical
    # error of the naive Paddle bf16 implementation, plus a small absolute tolerance.
assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol
# return
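
    # Backward: push the same random upstream gradient through all three paths
    # and apply the same tolerance scheme to dQ, dK and dV.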
g = paddle.randn(shape=out.shape, dtype=out.dtype)
paddle.save(g, 'g.pd')
out.backward(g)
out_ref.backward(g)
out_bf16.backward(g)
paddle.device.synchronize()
print(f"flashmask dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
print(f"flashmask dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
print(f"flashmask dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
print(f"flashmask dQ mean diff: {(q.grad - q_ref.grad).abs().mean().item()}")
print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}")
print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}")
print(f"Paddle naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}")
print(f"Paddle naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}")
print(f"Paddle naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}")
print(f"Paddle naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}")
print(f"Paddle naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}")
print(f"Paddle naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}")
dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - q_ref.grad).abs().max().item() + dq_atol
dk_atol = 2 * (k_ref.grad + 0.3 - 0.3 - k_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
assert (k.grad - k_ref.grad).abs().max().item() <= rtol * (k_bf16.grad - k_ref.grad).abs().max().item() + dk_atol
dv_atol = 2 * (v_ref.grad + 0.3 - 0.3 - v_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4)
assert (v.grad - v_ref.grad).abs().max().item() <= rtol * (v_bf16.grad - v_ref.grad).abs().max().item() + dv_atol