-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathtransformer_layer_utils.py
More file actions
104 lines (92 loc) · 4 KB
/
transformer_layer_utils.py
File metadata and controls
104 lines (92 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from absl import logging
from keras import ops
from keras.src.backend import get_keras_mask
def _check_masks_shapes(inputs, padding_mask, attention_mask):
mask = padding_mask
if mask is None:
mask = get_keras_mask(inputs)
if mask is not None:
if len(mask.shape) != 2:
raise ValueError(
"`padding_mask` should have shape "
"(batch_size, target_length). "
f"Received shape `{mask.shape}`."
)
if attention_mask is not None:
if len(attention_mask.shape) != 3:
raise ValueError(
"`attention_mask` should have shape "
"(batch_size, target_length, source_length). "
f"Received shape `{mask.shape}`."
)
def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
    """Compute a causal attention mask for a transformer decoder.

    Args:
        batch_size: batch size for the mask.
        input_length: the length of key/value tensors in the attention layer.
        output_length: the length of query tensors in the attention layer.
        cache_index: the current index for cached generation. If passed, the
            query sequence will be considered to start at `cache_index` rather
            than zero. For example, a causal mask with `output_length=1` and
            `cache_index=5` would allow the query tensor to attend to the first
            five positions of the key/value tensors.

    Return:
        A causal attention mask with shape
        `(batch_size, output_length, input_length)` that can be passed to a
        attention layer.
    """
    # Key/value positions 0..input_length-1, shared by both code paths.
    key_positions = ops.arange(input_length, dtype="float32")
    if isinstance(output_length, int) and output_length == 1:
        # Single-query path, hit during autoregressive generation: the only
        # query row may attend to every key position up to `cache_index`.
        # Skips the query-side arange/add of the general path below.
        allowed = key_positions <= ops.cast(cache_index, "float32")
        row = ops.expand_dims(ops.expand_dims(allowed, axis=0), axis=0)
        return ops.broadcast_to(row, (batch_size, 1, input_length))
    # General path: query position i (offset by cache_index) attends to key
    # position j iff i >= j. Shapes: (output_length, 1) vs (input_length,).
    query_positions = ops.expand_dims(
        ops.arange(output_length, dtype="float32")
        + ops.cast(cache_index, "float32"),
        axis=1,
    )
    causal = ops.expand_dims(query_positions >= key_positions, axis=0)
    return ops.broadcast_to(causal, (batch_size, output_length, input_length))
def merge_padding_and_attention_mask(
    inputs,
    padding_mask,
    attention_mask,
):
    """Merge the padding mask with a customized attention mask.

    Args:
        inputs: the input sequence.
        padding_mask: the 1D padding mask, of shape
            [batch_size, sequence_length].
        attention_mask: the 2D customized mask, of shape
            [batch_size, sequence1_length, sequence2_length].

    Return:
        A merged 2D mask or None. If only `padding_mask` is provided, the
        returned mask is padding_mask with one additional axis.
    """
    _check_masks_shapes(inputs, padding_mask, attention_mask)
    # Start with any mask keras attached to the inputs themselves.
    mask = get_keras_mask(inputs)
    if padding_mask is not None:
        # An explicit padding mask takes precedence over the built-in one.
        if mask is not None:
            logging.warning(
                "You are explicitly setting `padding_mask` while the `inputs` "
                "have built-in mask, so the built-in mask is ignored."
            )
        mask = padding_mask
    if mask is not None:
        # Insert a broadcast axis so the 1D padding mask lines up with the
        # (batch, query, key) attention mask layout.
        mask = ops.cast(ops.expand_dims(mask, axis=1), "int32")
    if attention_mask is None:
        return mask
    attention_mask = ops.cast(attention_mask, "int32")
    if mask is None:
        return attention_mask
    # Both masks present: a position survives only if both masks allow it.
    return ops.minimum(mask, attention_mask)