-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathposition_embedding.py
More file actions
140 lines (122 loc) · 5.18 KB
/
position_embedding.py
File metadata and controls
140 lines (122 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import keras
from keras import ops
from keras_hub.src.api_export import keras_hub_export
@keras_hub_export("keras_hub.layers.PositionEmbedding")
class PositionEmbedding(keras.layers.Layer):
    """A layer which learns a position embedding for input sequences.

    The layer owns a single trainable table with one row per position. When
    called, it selects the rows matching the positions of the input sequence
    and broadcasts them to the shape of the input.

    This class assumes that in the input tensor, the last dimension
    corresponds to the features, and the dimension before the last
    corresponds to the sequence.

    This layer does not support masking, but can be combined with a
    `keras.layers.Embedding` for padding mask support.

    Args:
        sequence_length: The maximum length of the dynamic sequence. Must
            not be `None`; the value is converted with `int()`.
        initializer: The initializer to use for the embedding weights.
            Defaults to `"glorot_uniform"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to compute an embedding for, with shape
            `(batch_size, sequence_length, hidden_dim)`. Only the input
            shape will be used, as the position embedding does not depend
            on the input sequence content.
        start_index: An integer or integer tensor. The starting position to
            compute the position embedding from. This is useful during
            cached decoding, where each position is predicted separately in
            a loop.
        positions: Tensor of shape `(sequence_length,)` or
            `(batch_size, sequence_length)`. Custom positions for the input
            sequence. If specified, this tensor will be used to compute the
            position embedding, and the `start_index` argument will be
            ignored. This is useful for cases with non-standard positions.

    Example:

    Called directly on input.
    >>> layer = keras_hub.layers.PositionEmbedding(sequence_length=10)
    >>> layer(np.zeros((8, 10, 16)))

    Combine with a token embedding.
    ```python
    seq_length = 50
    vocab_size = 5000
    embed_dim = 128
    inputs = keras.Input(shape=(seq_length,))
    token_embeddings = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embed_dim
    )(inputs)
    position_embeddings = keras_hub.layers.PositionEmbedding(
        sequence_length=seq_length
    )(token_embeddings)
    outputs = token_embeddings + position_embeddings
    ```

    Reference:
     - [Devlin et al., 2019](https://arxiv.org/abs/1810.04805)
    """

    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        """Store the table size and resolve the weight initializer.

        Raises:
            ValueError: If `sequence_length` is `None`.
        """
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError(
                "`sequence_length` must be an Integer, received `None`."
            )
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        """Return the base layer config extended with this layer's args."""
        config = super().get_config()
        config["sequence_length"] = self.sequence_length
        config["initializer"] = keras.initializers.serialize(self.initializer)
        return config

    def build(self, inputs_shape):
        """Create the trainable `(sequence_length, feature)` table.

        The feature size is read from the last entry of `inputs_shape`.
        """
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, inputs_shape[-1]],
            initializer=self.initializer,
            trainable=True,
        )
        self.built = True

    def call(self, inputs, start_index=0, positions=None):
        """Return position embeddings broadcast to the shape of `inputs`.

        Args:
            inputs: Tensor whose shape determines the output shape; its
                values are not read.
            start_index: First table row to use when `positions` is not
                given.
            positions: Optional explicit row indices. When given,
                `start_index` is ignored.

        Returns:
            A tensor with the same shape as `inputs`.
        """
        input_shape = ops.shape(inputs)
        seq_len = input_shape[-2]
        feature_dim = input_shape[-1]
        table = ops.convert_to_tensor(self.position_embeddings)

        if positions is not None:
            # Explicit positions: gather the requested rows. Unbatched
            # `positions` get a leading axis first.
            if len(ops.shape(positions)) == 1:
                positions = ops.expand_dims(positions, axis=0)
            embeddings = ops.take(table, positions, axis=0)
        elif (
            isinstance(seq_len, int)
            and seq_len == 1
            and isinstance(start_index, int)
        ):
            # Fast path for single-token cached decoding on torch: index
            # directly instead of calling `ops.slice` to avoid overhead.
            # Only taken when both the sequence length and `start_index`
            # are static Python ints (not traced values as under JAX JIT).
            single_row = table[start_index : start_index + 1, :]
            embeddings = ops.expand_dims(single_row, axis=0)
        else:
            # Trim the table to the length of the input sequence, which
            # might be less than the `sequence_length` of the layer.
            embeddings = ops.slice(
                table,
                (start_index, 0),
                (seq_len, feature_dim),
            )
        return ops.broadcast_to(embeddings, input_shape)

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged."""
        return input_shape