-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathqwen_moe_attention.py
More file actions
376 lines (329 loc) · 13 KB
/
qwen_moe_attention.py
File metadata and controls
376 lines (329 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
import inspect
import math
import keras
from keras import ops
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu
class QwenMoeAttention(keras.layers.Layer):
    """A multi-head attention layer for Qwen-Moe model.

    This attention implementation supports grouped-query attention (GQA) where
    the number of key-value heads can be less than the number of query heads.
    Key/value heads are repeated `num_query_heads // num_key_value_heads`
    times so that they line up with the query heads.

    Args:
        num_query_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads (for GQA). Must be
            positive and evenly divide `num_query_heads`.
        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
            Embedding).
        rope_scaling_factor: Scaling factor for RoPE, used for extending
            context length.
        kernel_initializer: Initializer for the kernel weights.
        bias_initializer: Initializer for the bias weights.
        dropout: Dropout rate. In this implementation dropout is applied to
            the attention output before the output projection; any rate
            greater than zero also disables the fused attention op.
        use_sliding_window_attention: Whether to use sliding window
            attention.
        sliding_window_size: Size of the sliding window for attention.
        **kwargs: Additional keyword arguments to pass to the Layer.

    Raises:
        ValueError: If `num_key_value_heads` is not positive or does not
            evenly divide `num_query_heads`.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        dropout=0,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        # Fail early with a clear message: a non-divisible head count would
        # otherwise only surface later as an einsum shape mismatch.
        if (
            num_key_value_heads <= 0
            or num_query_heads % num_key_value_heads != 0
        ):
            raise ValueError(
                "`num_query_heads` must be a positive multiple of "
                "`num_key_value_heads`. Received: "
                f"num_query_heads={num_query_heads}, "
                f"num_key_value_heads={num_key_value_heads}"
            )
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.dropout = dropout

        # Number of query heads that share each key/value head.
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength

        self.kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )
        self.bias_initializer = keras.initializers.get(
            clone_initializer(bias_initializer)
        )

        self.rope_scaling_factor = rope_scaling_factor
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        # Not read anywhere in this class; kept as a public attribute.
        self.logit_soft_cap = None

    def build(self, inputs_shape):
        """Creates the projection, softmax, dropout and RoPE sub-layers.

        Args:
            inputs_shape: Shape of the `hidden_states` input; its last entry
                is used as the model (hidden) dimension.
        """
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        hidden_dim = inputs_shape[-1]
        head_dim = hidden_dim // self.num_query_heads
        # Scale applied to the query/key dot products: 1 / sqrt(head_dim).
        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, head_dim),
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            bias_axes="uh",
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            bias_axes="vh",
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                head_dim,
            ),
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            bias_axes="vh",
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        # The softmax layer is configured with float32 regardless of the
        # layer's dtype policy; results are cast back to the compute dtype
        # in `_compute_attention`.
        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        # The output projection has no bias (no `bias_axes` given).
        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, hidden_dim),
            kernel_initializer=self.kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build((None, None, self.num_query_heads, head_dim))

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Applies attention mechanism to the input hidden states.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_length,
                hidden_size].
            attention_mask: Mask tensor of shape [batch_size, query_length,
                key_length].
            cache: Optional cached key and value tensors, indexed as
                `cache[:, 0, ...]` for keys and `cache[:, 1, ...]` for
                values.
            cache_update_index: Index along the sequence axis at which to
                write the newly computed keys/values into the cache. Also
                used as the RoPE start index.
            training: Boolean indicating whether in training mode.

        Returns:
            attention_output: Output tensor after applying attention.
            cache: Updated cache tensors (if cache is provided).

        Raises:
            ValueError: If `cache_update_index` is set while `cache` is
                `None`.
        """
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self.key_dense(x), self.value_dense(x)
            # Compute RoPE for keys
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only use of the cache: attend over what is stored.
                key = key_cache
                value = value_cache
            else:
                # Write the new keys/values into the cache at the given
                # sequence position, then attend over the updated cache.
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            cache_update_index=cache_update_index,
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        """Applies softmax with optional masking.

        Args:
            attention_scores: Attention score tensor.
            attention_mask: Optional mask tensor. A head axis is inserted at
                position 1 before it is handed to the softmax layer.

        Returns:
            Masked softmax attention weights.
        """
        if attention_mask is not None:
            return self._softmax(
                attention_scores, ops.expand_dims(attention_mask, axis=1)
            )
        return self._softmax(attention_scores)

    def _use_fused_attention_op(self):
        """Returns whether `ops.dot_product_attention` should be used.

        The fused op is skipped when it is unavailable, when dropout is
        enabled, or when running on a device other than GPU/TPU.
        """
        if not fused_attention_op_available():
            return False
        if self.dropout > 0.0:
            return False
        if running_on_gpu():
            return gpu_supports_fused_attention_op()
        elif running_on_tpu():
            # TPU supports softcap on keras >= 3.10; detect it through the
            # presence of the `attn_logits_soft_cap` argument.
            sig = inspect.signature(ops.dot_product_attention)
            return "attn_logits_soft_cap" in sig.parameters
        else:
            return False

    def _compute_attention(
        self,
        query,
        key,
        value,
        attention_mask=None,
        cache_update_index=None,
        **kwargs,
    ):
        """Computes attention using query, key, and value tensors.

        Uses the fused `ops.dot_product_attention` op when
        `_use_fused_attention_op()` allows it, and an explicit
        einsum/softmax implementation otherwise. When sliding window
        attention is enabled the window is folded into the mask before
        either path runs, so both paths see the same mask.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            attention_mask: Optional mask tensor.
            cache_update_index: Index for sliding window computation.
            **kwargs: Extra keyword arguments forwarded to
                `ops.dot_product_attention` on the fused path only.

        Returns:
            attention_output: Output tensor after applying attention.
        """
        if self.use_sliding_window_attention:
            if attention_mask is None:
                # No caller-provided mask: start from an all-True mask
                # (broadcast over batch) so the window can still be applied.
                attention_mask = ops.ones(
                    (1, ops.shape(query)[1], ops.shape(key)[1]),
                    dtype="bool",
                )
            attention_mask = self._mask_sliding_window(
                attention_mask,
                # Compare against None rather than relying on truthiness,
                # since the index may be a tensor.
                cache_update_index=(
                    cache_update_index
                    if cache_update_index is not None
                    else 0
                ),
            )

        if self._use_fused_attention_op():
            if attention_mask is not None:
                # Insert the head axis and hand the op a boolean mask.
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")

            attention_output = ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
                **kwargs,
            )
            return attention_output

        attention_scores = ops.einsum(self._dot_product_equation, query, key)

        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        # The softmax layer runs in float32; return to the compute dtype.
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        """Creates and combines a sliding window mask with the attention mask.

        Args:
            attention_mask: Original attention mask with three dimensions;
                the last two are read as query length and key length.
            cache_update_index: Starting index for the sliding window.

        Returns:
            Combined attention mask with sliding window constraints.
        """
        _, query_len, key_len = ops.shape(attention_mask)
        # Compute the sliding window for square attention: keep entries
        # within `sliding_window_size - 1` positions of the diagonal.
        all_ones = ops.ones((key_len, key_len), "bool")
        sliding_mask = ops.logical_and(
            ops.triu(all_ones, -1 * self.sliding_window_size + 1),
            ops.tril(all_ones, self.sliding_window_size - 1),
        )
        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        # Add a leading axis so the window broadcasts over the batch.
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def get_config(self):
        """Returns the layer configuration, including constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "dropout": self.dropout,
                "use_sliding_window_attention": (
                    self.use_sliding_window_attention
                ),
                "sliding_window_size": self.sliding_window_size,
            }
        )
        return config