forked from keras-team/keras-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmit_layers.py
More file actions
293 lines (249 loc) · 9.5 KB
/
mit_layers.py
File metadata and controls
293 lines (249 loc) · 9.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import math
import keras
from keras import ops
from keras import random
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
    """Overlapping Patching and Embedding layer.

    Differs from `PatchingAndEmbedding` in that the patch size does not
    affect the sequence length. It's fully derived from the `stride`
    parameter. Additionally, no positional embedding is done
    as part of the layer - only a projection using a `Conv2D` layer.

    The input is zero-padded by `patch_size // 2` on each spatial side,
    projected with a `Conv2D` (kernel `patch_size`, stride `stride`,
    `"valid"` padding), flattened into a sequence and layer-normalized.
    For an odd `patch_size` this yields `ceil(H / stride) * ceil(W / stride)`
    tokens.

    Args:
        project_dim: integer, the dimensionality of the projection.
            Defaults to `32`.
        patch_size: integer, the size of the patches to encode.
            Defaults to `7`.
        stride: integer, the stride to use for the patching before
            projection. Defaults to `4`.
    """

    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.patch_size = patch_size
        self.stride = stride

        # Pad by half the kernel so that (for odd kernels) the number of
        # output tokens depends only on `stride`, not on `patch_size`.
        padding_size = self.patch_size // 2
        self.padding = keras.layers.ZeroPadding2D(
            padding=(padding_size, padding_size)
        )
        self.proj = keras.layers.Conv2D(
            filters=project_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="valid",
        )
        self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """Embed a 4D image batch into a normalized token sequence.

        Args:
            x: 4D image tensor. Assumed channels-last, since the channel
                count is read from axis 3 after projection, and assumed to
                have statically known spatial dims (TODO confirm for
                dynamic-shape inputs).

        Returns:
            Tensor of shape `(batch, H' * W', project_dim)`, where `H'` and
            `W'` are the spatial dims after the strided projection.
        """
        x = self.padding(x)
        x = self.proj(x)
        # Flatten the (H', W') grid into a single sequence axis.
        x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3]))
        x = self.norm(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "patch_size": self.patch_size,
                "stride": self.stride,
            }
        )
        return config
class HierarchicalTransformerEncoder(keras.layers.Layer):
    """Hierarchical transformer encoder block implementation as a Keras Layer.

    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
    alternative for computational efficiency, and is meant to be used
    within the SegFormer architecture. The block applies two pre-norm
    residual branches, attention followed by a `MixFFN` with a 4x hidden
    width, each passed through `DropPath` before the residual addition.

    Args:
        project_dim: integer, the dimensionality of the projection of the
            encoder, and output of the `SegFormerMultiheadAttention` layer.
            Due to the residual addition the input dimensionality has to be
            equal to the output dimensionality.
        num_heads: integer, the number of heads for the
            `SegFormerMultiheadAttention` layer.
        sr_ratio: integer, the ratio to use within
            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
            layer is used to reduce the length of the sequence. Defaults to `1`.
        drop_prob: float, the probability of dropping a random
            sample using the `DropPath` layer. Defaults to `0.0`.
        layer_norm_epsilon: float, the epsilon for
            `LayerNormalization` layers. Defaults to `1e-06`.
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.drop_prob = drop_prob
        self.layer_norm_epsilon = layer_norm_epsilon
        # Backward-compatible alias for the historically misspelled attribute.
        self.drop_prop = drop_prob

        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = SegFormerMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        self.drop_path = DropPath(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def build(self, input_shape):
        """Mark the layer built and record the feature-map side length.

        The input is a `(batch, sequence, channels)` sequence. The attention
        and MixFFN sub-layers treat it as a flattened square grid, so both
        `H` and `W` are `sqrt(sequence)`. `input_shape[2]` is the channel
        dimension, not a spatial one. When the sequence length is not
        statically known, `H` and `W` are set to `None`.
        """
        super().build(input_shape)
        seq_len = input_shape[1]
        if seq_len is None:
            self.H = None
            self.W = None
        else:
            side = ops.sqrt(ops.cast(seq_len, "float32"))
            self.H = side
            self.W = side

    def call(self, x):
        """Apply the attention and MixFFN residual branches to `x`."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer.

        Only `__init__` arguments are included. Sub-layers such as `mlp` are
        recreated by `__init__`. The default `from_config` calls
        `cls(**config)`, so any key that is not an `__init__` parameter
        would be forwarded to `keras.layers.Layer` and rejected.
        """
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
                "drop_prob": self.drop_prob,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
class MixFFN(keras.layers.Layer):
    """Mix-FFN block: Dense -> 3x3 depthwise conv -> GELU -> Dense.

    The token sequence is reshaped into a square 2D grid (side
    `int(sqrt(sequence_length))`) so that the depthwise convolution can mix
    neighbouring tokens, and is then flattened back.

    Args:
        channels: integer, the output dimensionality of the block.
        mid_channels: integer, the hidden dimensionality of the first
            `Dense` projection and of the depthwise convolution.
        **kwargs: standard `keras.layers.Layer` arguments (e.g. `name`,
            `dtype`).
    """

    def __init__(self, channels, mid_channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.mid_channels = mid_channels
        self.fc1 = keras.layers.Dense(mid_channels)
        self.dwconv = keras.layers.DepthwiseConv2D(
            kernel_size=3,
            strides=1,
            padding="same",
        )
        self.fc2 = keras.layers.Dense(channels)

    def call(self, x):
        """Apply the Mix-FFN to a `(batch, sequence, features)` tensor.

        Assumes the sequence length is a perfect square (a flattened square
        feature map); the grid side is `int(sqrt(sequence_length))`.
        """
        x = self.fc1(x)
        shape = ops.shape(x)
        batch, seq_len, chans = shape[0], shape[1], shape[2]
        side = int(math.sqrt(seq_len))
        x = ops.reshape(x, (batch, side, side, chans))
        x = self.dwconv(x)
        x = ops.reshape(x, (batch, -1, chans))
        x = ops.nn.gelu(x)
        x = self.fc2(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "mid_channels": self.mid_channels,
            }
        )
        return config
class SegFormerMultiheadAttention(keras.layers.Layer):
    """Efficient MultiHeadAttention implementation as a Keras layer.

    A huge bottleneck in scaling transformers is the self-attention layer,
    with O(n^2) cost in the sequence length n. When `sr_ratio > 1`, this
    layer performs a sequence reduction (SR) before the key and value
    projections. The sequence is reshaped to a square 2D grid and passed
    through a `Conv2D` whose kernel size and stride both equal `sr_ratio`.
    That shrinks each spatial side by `sr_ratio`, so the key/value sequence
    becomes roughly `sr_ratio**2` times shorter and the attention cost drops
    to roughly O(n^2 / sr_ratio^2).

    Args:
        project_dim: integer, the dimensionality of the projection
            of the `SegFormerMultiheadAttention` layer.
        num_heads: integer, the number of heads to use in the
            attention computation.
        sr_ratio: integer, the spatial reduction ratio applied to each side
            of the feature map before the key and value projections. Values
            `<= 1` disable the reduction.
        **kwargs: standard `keras.layers.Layer` arguments (e.g. `name`,
            `dtype`).
    """

    def __init__(self, project_dim, num_heads, sr_ratio, **kwargs):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.scale = (project_dim // num_heads) ** -0.5
        self.q = keras.layers.Dense(project_dim)
        self.k = keras.layers.Dense(project_dim)
        self.v = keras.layers.Dense(project_dim)
        self.proj = keras.layers.Dense(project_dim)
        self.dropout = keras.layers.Dropout(0.1)
        self.proj_drop = keras.layers.Dropout(0.1)
        if sr_ratio > 1:
            self.sr = keras.layers.Conv2D(
                filters=project_dim,
                kernel_size=sr_ratio,
                strides=sr_ratio,
            )
            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x):
        """Attend over `x`, a `(batch, sequence, channels)` tensor.

        When `sr_ratio > 1`, the sequence length is assumed to be a perfect
        square (a flattened square feature map).
        """
        input_shape = ops.shape(x)
        B, N, C = input_shape[0], input_shape[1], input_shape[2]
        head_dim = C // self.num_heads

        # Queries always use the full-resolution sequence.
        q = self.q(x)
        q = ops.reshape(q, (B, N, self.num_heads, head_dim))
        q = ops.transpose(q, [0, 2, 1, 3])

        if self.sr_ratio > 1:
            # Only the reduction path needs the 2D grid layout.
            H = W = int(math.sqrt(N))
            x = ops.reshape(x, (B, H, W, C))
            x = self.sr(x)
            x = ops.reshape(x, [B, -1, C])
            x = self.norm(x)

        # Keys/values use the (possibly reduced) sequence.
        k = self.k(x)
        v = self.v(x)
        k = ops.transpose(
            ops.reshape(k, [B, -1, self.num_heads, head_dim]),
            [0, 2, 1, 3],
        )
        v = ops.transpose(
            ops.reshape(v, [B, -1, self.num_heads, head_dim]),
            [0, 2, 1, 3],
        )

        attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
        attn = ops.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)
        attn = attn @ v
        # Merge heads back into the channel axis.
        attn = ops.reshape(
            ops.transpose(attn, [0, 2, 1, 3]),
            [B, N, C],
        )
        x = self.proj(attn)
        x = self.proj_drop(x)
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
            }
        )
        return config
class DropPath(keras.layers.Layer):
    """Implements the DropPath layer.

    During training, each sample in the batch is independently zeroed with
    probability `rate`, and the surviving samples are rescaled by
    `1 / (1 - rate)` so the expected value is unchanged. Note that this
    layer drops individual samples within a batch, not the entire batch at
    once. Outside training, or when `rate == 0.0`, the input is returned
    unchanged.

    Args:
        rate: float, the probability of the residual branch being dropped.
            Defaults to `0.5`. A `rate` of `1.0` or more drops every
            sample (the output is all zeros).
        seed: (Optional) integer. Used to create a random seed.
    """

    def __init__(self, rate=0.5, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        # Keep the raw seed for `get_config`; `self.seed` holds the generator.
        self._seed_val = seed
        self.seed = random.SeedGenerator(seed=seed)

    def call(self, x, training=None):
        if self.rate == 0.0 or not training:
            return x
        if self.rate >= 1.0:
            # Every sample is dropped. Return zeros directly: the rescale
            # below would divide by zero and yield NaN instead of zeros.
            return ops.zeros_like(x)
        batch_size = x.shape[0] or ops.shape(x)[0]
        # One keep/drop decision per sample, broadcast over all other axes.
        drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
        drop_map = ops.cast(
            random.uniform(drop_map_shape, seed=self.seed) > self.rate,
            x.dtype,
        )
        x = x / (1.0 - self.rate)
        x = x * drop_map
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({"rate": self.rate, "seed": self._seed_val})
        return config