-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathcached_multi_head_attention.py
More file actions
247 lines (215 loc) · 9.67 KB
/
cached_multi_head_attention.py
File metadata and controls
247 lines (215 loc) · 9.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import keras
from keras import ops
from keras_hub.src.api_export import keras_hub_export
# Lazily probed flag for SDPA support under the PyTorch backend.
# None = not yet probed; True/False = cached probe result.
_TORCH_SDPA_AVAILABLE = None


def _check_torch_sdpa():
    """Return True if torch exposes `scaled_dot_product_attention`.

    The probe runs at most once; the result is memoized in the
    module-level `_TORCH_SDPA_AVAILABLE` flag so repeated calls are
    cheap on the hot inference path.
    """
    global _TORCH_SDPA_AVAILABLE
    if _TORCH_SDPA_AVAILABLE is not None:
        return _TORCH_SDPA_AVAILABLE
    try:
        import torch.nn.functional as F
    except ImportError:
        # Torch is not installed; SDPA fast path is unavailable.
        _TORCH_SDPA_AVAILABLE = False
    else:
        _TORCH_SDPA_AVAILABLE = hasattr(F, "scaled_dot_product_attention")
    return _TORCH_SDPA_AVAILABLE
@keras_hub_export("keras_hub.layers.CachedMultiHeadAttention")
class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
    """MultiHeadAttention layer with cache support.

    This layer is suitable for use in autoregressive decoding. It can be used
    to cache decoder self-attention and cross-attention. The forward pass
    can happen in one of three modes:

    - No cache, same as regular multi-head attention.
    - Static cache (`cache_update_index` is None). In this case, the
      cached key/value projections will be used and the input values will
      be ignored.
    - Updated cache (`cache_update_index` is not None). In this case, new
      key/value projections are computed using the input, and spliced into
      the cache at the specified index.

    Note that caching is useful only during inference and should not be used
    during training.

    We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
    `T` is the target sequence length, and `S` is the source sequence length.
    Note that during generative decoding, `T` is usually 1 (you are
    generating a target sequence of length one to predict the next token).

    Call arguments:
        query: Query `Tensor` of shape `(B, T, dim)`.
        value: Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `None`,
            `S*` must equal `S` and match the shape of `attention_mask`. If
            `cache` is not `None`, `S*` can be any length less than `S`, and
            the computed value will be spliced into `cache` at
            `cache_update_index`.
        key: Optional key `Tensor` of shape `(B, S*, dim)`. If `cache` is
            `None`, `S*` must equal `S` and match the shape of
            `attention_mask`. If `cache` is not `None`, `S*` can be any length
            less than `S`, and the computed value will be spliced into `cache`
            at `cache_update_index`.
        attention_mask: a boolean mask of shape `(B, T, S)`. `attention_mask`
            prevents attention to certain positions. The boolean mask specifies
            which query elements can attend to which key elements, 1 indicates
            attention and 0 indicates no attention. Broadcasting can happen for
            the missing batch dimensions and the head dimension.
        cache: a dense float Tensor. The key/value cache, of shape
            `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the
            `attention_mask` shape. This argument is intended for use during
            generation to avoid recomputing intermediate state.
        cache_update_index: an int or int Tensor, the index at which to update
            `cache` (usually the index of the current token being processed
            when running generation). If `cache_update_index=None` while
            `cache` is set, the cache will not be updated.
        training: a boolean indicating whether the layer should behave in
            training mode or in inference mode.

    Returns:
        An `(attention_output, cache)` tuple. `attention_output` is the result
        of the computation, of shape `(B, T, dim)`, where `T` is for target
        sequence shapes and `dim` is the query input last dimension if
        `output_shape` is `None`. Otherwise, the multi-head outputs are
        projected to the shape specified by `output_shape`. `cache` is the
        updated cache. Note: when `cache` is `None`, only `attention_output`
        is returned (no tuple).
    """

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        # Standard MHA convention: missing key defaults to value.
        if key is None:
            key = value

        query = self._query_dense(query)

        # If cache is not `None`, we will use the cache to compute the final key
        # and value tensors. If `cache_update_index` is not None, we will first
        # update the cache before use. To do this, we first call the
        # `_key_dense` and `_value_dense` layers, and copy the outputs into the
        # cache at the specified index. `cache = None` handles the training
        # case, where we don't use the cache at all.
        if cache is not None:
            # Cache layout is `[B, 2, S, num_heads, key_dims]`: index 0 along
            # axis 1 holds keys, index 1 holds values.
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Static cache: reuse projections; inputs `key`/`value` are
                # ignored.
                key = key_cache
                value = value_cache
            else:
                # Project the new token(s) and splice them into the cache at
                # the sequence position `cache_update_index`.
                key_update = self._key_dense(key)
                value_update = self._value_dense(value)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            # No-cache path (e.g. training): an update index is meaningless
            # here, so reject it explicitly.
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key = self._key_dense(key)
            value = self._value_dense(value)

        attention_output, attention_scores = self._compute_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=training,
        )
        attention_output = self._output_dense(attention_output)

        # Only return a tuple when caching is active, to stay drop-in
        # compatible with the parent layer's single-output contract.
        if cache is not None:
            return attention_output, cache
        return attention_output

    def call_cached(
        self,
        query,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
    ):
        """Ultra-fast path for cached autoregressive decoding.

        Bypasses Layer.__call__ overhead on all sublayers by calling
        .call() directly. This is safe because:

        - All sublayers are already built
        - Input dtypes are already correct (same dtype flows through)
        - No masking metadata needed
        - No training-mode checks needed (always inference)
        - No autocast scope changes needed

        This saves ~5 Layer.__call__ invocations per attention layer
        (query_dense, key_dense, value_dense, output_dense, plus the
        attention layer itself).

        NOTE(review): unlike `call`, this path takes no `key`/`value`
        arguments — key/value updates are projected from `query`, so it
        only supports self-attention. It also assumes `cache` is not
        `None` (no guard) — confirm callers guarantee both.
        """
        # Directly call .call() on dense layers, bypassing Layer.__call__
        query_proj = self._query_dense.call(query)

        # Same cache layout as `call`: axis 1 index 0 = keys, 1 = values.
        key_cache = cache[:, 0, ...]
        value_cache = cache[:, 1, ...]
        if cache_update_index is None:
            key = key_cache
            value = value_cache
        else:
            # Self-attention: key/value projections come from `query`.
            key_update = self._key_dense.call(query)
            value_update = self._value_dense.call(query)
            start = [0, cache_update_index, 0, 0]
            key = ops.slice_update(key_cache, start, key_update)
            value = ops.slice_update(value_cache, start, value_update)
            cache = ops.stack((key, value), axis=1)

        # Always inference on this path, so `training=False` and attention
        # scores are discarded.
        attention_output, _ = self._compute_attention(
            query=query_proj,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=False,
        )
        attention_output = self._output_dense.call(attention_output)
        return attention_output, cache

    def _compute_attention(
        self,
        query,
        key,
        value,
        attention_mask=None,
        training=None,
        return_attention_scores=False,
    ):
        """Override to use SDPA during cached inference on torch.

        Only activated when `_use_sdpa_override` is set True (by the
        TransformerDecoder.call_cached fast path for self-attention).
        Falls back to the parent implementation otherwise.
        """
        # Take the fused-SDPA fast path only when every precondition holds:
        # torch backend, no scores requested, inference mode, per-head 4D
        # inputs, SDPA present in this torch build, and the opt-in flag set.
        if (
            keras.config.backend() == "torch"
            and not return_attention_scores
            and (training is None or training is False)
            and len(query.shape) == 4
            and _check_torch_sdpa()
            and getattr(self, "_use_sdpa_override", False)
        ):
            import torch
            import torch.nn.functional as F

            # Transpose from (B, S, H, D) to (B, H, S, D) for SDPA.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            # Convert attention mask to SDPA format.
            # Both Keras and PyTorch SDPA use the same convention for bool
            # masks: True = attend, False = mask out.
            # No inversion needed - pass through directly.
            if attention_mask is not None:
                attention_mask = attention_mask.to(dtype=torch.bool)
                # Unsqueeze up to rank 4 so the mask broadcasts over the
                # batch/head dimensions, e.g. (B, T, S) -> (B, 1, T, S).
                while attention_mask.dim() < 4:
                    attention_mask = attention_mask.unsqueeze(1)
            # Inference only, so dropout is disabled. SDPA's default scale
            # (1/sqrt of the query head dim) matches the parent layer's
            # 1/sqrt(key_dim) scaling here.
            attention_output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
            )
            # Transpose back from (B, H, T, D) to (B, T, H, D).
            attention_output = attention_output.transpose(1, 2)
            # Scores were not requested on this path; return None in their
            # place to keep the two-value unpacking at call sites working.
            return attention_output, None
        return super()._compute_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=training,
            return_attention_scores=return_attention_scores,
        )