This repository was archived by the owner on Mar 10, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathtransformer_encoder.py
More file actions
152 lines (132 loc) · 5.14 KB
/
transformer_encoder.py
File metadata and controls
152 lines (132 loc) · 5.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers
from keras_cv.api_export import keras_cv_export
@keras_cv_export("keras_cv.layers.TransformerEncoder")
class TransformerEncoder(layers.Layer):
"""
Transformer encoder block implementation as a Keras Layer.
Args:
project_dim: the dimensionality of the projection of the encoder, and
output of the `MultiHeadAttention`
mlp_dim: the intermediate dimensionality of the MLP head before
projecting to `project_dim`
num_heads: the number of heads for the `MultiHeadAttention` layer
mlp_dropout: default 0.1, the dropout rate to apply between the layers
of the MLP head of the encoder
attention_dropout: default 0.1, the dropout rate to apply in the
MultiHeadAttention layer
activation: default 'tf.activations.gelu', the activation function to
apply in the MLP head - should be a function
layer_norm_epsilon: default 1e-06, the epsilon for `LayerNormalization`
layers
Basic usage:
```
project_dim = 1024
mlp_dim = 3072
num_heads = 4
encoded_patches = keras_cv.layers.PatchingAndEmbedding(
project_dim=project_dim,
patch_size=16)(img_batch)
trans_encoded = keras_cv.layers.TransformerEncoder(project_dim=project_dim,
mlp_dim = mlp_dim,
num_heads=num_heads)(encoded_patches)
print(trans_encoded.shape) # (1, 197, 1024)
```
"""
def __init__(
self,
project_dim,
num_heads,
mlp_dim,
mlp_dropout=0.1,
attention_dropout=0.1,
activation=keras.activations.gelu,
layer_norm_epsilon=1e-06,
**kwargs,
):
super().__init__(**kwargs)
self.project_dim = project_dim
self.mlp_dim = mlp_dim
self.num_heads = num_heads
self.mlp_dropout = mlp_dropout
self.attention_dropout = attention_dropout
self.activation = activation
self.layer_norm_epsilon = layer_norm_epsilon
self.mlp_units = [mlp_dim, project_dim]
self.layer_norm1 = layers.LayerNormalization(
epsilon=self.layer_norm_epsilon
)
self.layer_norm2 = layers.LayerNormalization(
epsilon=self.layer_norm_epsilon
)
self.attn = layers.MultiHeadAttention(
num_heads=self.num_heads,
key_dim=self.project_dim // self.num_heads,
dropout=self.attention_dropout,
)
self.dense1 = layers.Dense(self.mlp_units[0])
self.dense2 = layers.Dense(self.mlp_units[1])
def call(self, inputs):
"""Calls the Transformer Encoder on an input sequence.
Args:
inputs: A `tf.Tensor` of shape [batch, height, width, channels]
Returns:
`A tf.Tensor` of shape [batch, patch_num+1, embedding_dim]
"""
if inputs.shape[-1] != self.project_dim:
raise ValueError(
"The input and output dimensionality must be the same, but the "
f"TransformerEncoder was provided with {inputs.shape[-1]} and "
f"{self.project_dim}"
)
x = self.layer_norm1(inputs)
x = self.attn(x, x)
x = layers.Dropout(self.mlp_dropout)(x)
x = layers.Add()([x, inputs])
y = self.layer_norm2(x)
y = self.dense1(y)
if self.activation == keras.activations.gelu:
y = self.activation(y, approximate=True)
else:
y = self.activation(y)
y = layers.Dropout(self.mlp_dropout)(y)
y = self.dense2(y)
y = layers.Dropout(self.mlp_dropout)(y)
output = layers.Add()([x, y])
return output
def get_config(self):
config = super().get_config()
activation = self.activation
if not isinstance(activation, (str, dict)):
activation = keras.activations.serialize(activation)
config.update(
{
"project_dim": self.project_dim,
"mlp_dim": self.mlp_dim,
"num_heads": self.num_heads,
"attention_dropout": self.attention_dropout,
"mlp_dropout": self.mlp_dropout,
"activation": activation,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config
@classmethod
def from_config(cls, config, custom_objects=None):
activation = config.pop("activation")
if isinstance(activation, (str, dict)):
activation = keras.activations.deserialize(activation)
return cls(activation=activation, **config)