-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathmixtral_attention.py
More file actions
254 lines (219 loc) · 8.55 KB
/
mixtral_attention.py
File metadata and controls
254 lines (219 loc) · 8.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import inspect
import math
import keras
from keras import ops
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu
class CachedMixtralAttention(keras.layers.Layer):
    """A cached grouped query attention layer with sliding window.

    Queries are projected to `num_query_heads` heads, while keys and values
    are projected to `num_key_value_heads` heads. Each key/value head is
    repeated `num_query_heads // num_key_value_heads` times so that it is
    shared by a group of query heads. Rotary position embeddings are applied
    to queries and keys. An optional cache of previously computed keys and
    values can be passed to `call()`, which returns it (updated when
    `cache_update_index` is given) alongside the attention output.

    Args:
        num_query_heads: int. Number of query heads.
        num_key_value_heads: int. Number of key/value heads. Must evenly
            divide `num_query_heads`.
        rope_max_wavelength: Forwarded to `RotaryEmbedding` as
            `max_wavelength`. Defaults to `10000`.
        rope_scaling_factor: Forwarded to `RotaryEmbedding` as
            `scaling_factor`. Defaults to `1.0`.
        kernel_initializer: Initializer used for the query, key, value and
            output projection kernels. Defaults to `"glorot_uniform"`.
        sliding_window: Stored on the layer and serialized by
            `get_config()`. The code in this class does not otherwise read
            it; any windowing is expected to be encoded in the
            `attention_mask` passed to `call()`. Defaults to `512`.
        dropout: Dropout rate applied to the attention output before the
            output projection. Defaults to `0`.
        **kwargs: Forwarded to `keras.layers.Layer`.

    Raises:
        ValueError: If `num_query_heads` is not a multiple of
            `num_key_value_heads`.
    """

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=512,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fail early with a clear message: a non-divisible configuration
        # would otherwise only surface later as a head-count mismatch
        # inside the attention einsum.
        if num_query_heads % num_key_value_heads != 0:
            raise ValueError(
                "`num_query_heads` must be divisible by "
                "`num_key_value_heads`. Received: "
                f"num_query_heads={num_query_heads}, "
                f"num_key_value_heads={num_key_value_heads}"
            )
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.sliding_window = sliding_window
        self.dropout = dropout

        # Number of query heads that share a single key/value head.
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self.rope_scaling_factor = rope_scaling_factor

    def build(self, inputs_shape):
        """Create and build the projection, softmax, dropout and RoPE layers.

        Args:
            inputs_shape: Shape of the `hidden_states` input. Its last
                dimension is taken as the model (hidden) dimension.
        """
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        self._head_dim = self._hidden_dim // self.num_query_heads
        # Attention logits are scaled by 1 / sqrt(head_dim).
        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self.num_query_heads, self._head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self.num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        # The softmax layer is pinned to float32 regardless of the layer's
        # dtype policy; its output is cast back in `_compute_attention`.
        self.softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self.rope_max_wavelength,
            scaling_factor=self.rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        """Run attention over `hidden_states`, optionally using a cache.

        Args:
            hidden_states: Input tensor whose last dimension is the model
                dimension the layer was built with.
            attention_mask: Optional mask. A head axis is inserted at
                axis 1 before it is applied to the attention scores, so it
                is presumably shaped `(batch, query_len, key_len)` -- confirm
                against the caller.
            cache: Optional key/value cache. Keys are read from
                `cache[:, 0, ...]` and values from `cache[:, 1, ...]`; each
                is updated along its second axis with tensors laid out as
                `(batch, length, num_key_value_heads, head_dim)`.
            cache_update_index: Optional position at which the newly
                computed keys/values are written into the cache. Also used
                as the RoPE start index. When `None` and a cache is given,
                the cached keys/values are used as-is.
            training: Forwarded to the dropout layer.

        Returns:
            The attention output, or an `(attention_output, cache)` tuple
            when `cache` is not `None`.

        Raises:
            ValueError: If `cache_update_index` is set while `cache` is
                `None`.
        """
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self.key_dense(x), self.value_dense(x)
            # Compute RoPE for keys
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only use of the cache: nothing new is projected.
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self.dropout_layer(
            attention_output, training=training
        )

        attention_output = self.output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        """Apply the softmax layer, masking the scores when a mask is given.

        The mask gets an extra axis at position 1 so that it lines up with
        the head axis of the `buqk` attention scores.
        """
        if attention_mask is not None:
            return self.softmax(
                attention_scores, ops.expand_dims(attention_mask, axis=1)
            )
        return self.softmax(attention_scores)

    def _use_fused_attention_op(self):
        """Return whether `ops.dot_product_attention` should be used.

        The fused op is skipped when it is unavailable, when dropout is
        enabled, or when running on neither a supported GPU nor a TPU.
        """
        if not fused_attention_op_available():
            return False
        if self.dropout > 0.0:
            return False
        if running_on_gpu():
            return gpu_supports_fused_attention_op()
        elif running_on_tpu():
            # On TPU the fused op is only used when the installed Keras
            # exposes `attn_logits_soft_cap` on `dot_product_attention`
            # (keras >= 3.10).
            sig = inspect.signature(ops.dot_product_attention)
            return "attn_logits_soft_cap" in sig.parameters
        else:
            return False

    def _compute_attention(self, query, key, value, attention_mask=None):
        """Compute scaled dot-product attention.

        Args:
            query: Query tensor laid out as `bquh`.
            key: Key tensor laid out as `bkuh` (already repeated to the
                number of query heads).
            value: Value tensor laid out as `bkuh`.
            attention_mask: Optional mask; a head axis is inserted at
                axis 1 before use.

        Returns:
            The attention output laid out as `bquh`.
        """
        if self._use_fused_attention_op():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")

            attention_output = ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
            )
            return attention_output

        attention_scores = ops.einsum(self._dot_product_equation, query, key)
        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        # The softmax layer runs in float32; cast back to the compute dtype
        # before combining with the values.
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        """Return the layer config, including all constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_max_wavelength": self.rope_max_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self.sliding_window,
                "dropout": self.dropout,
            }
        )
        return config