forked from keras-team/keras-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmit_backbone.py
More file actions
167 lines (149 loc) · 6.39 KB
/
mit_backbone.py
File metadata and controls
167 lines (149 loc) · 6.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_hub.src.models.mit.mit_layers import HierarchicalTransformerEncoder
from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding
@keras_hub_export("keras_hub.models.MiTBackbone")
class MiTBackbone(FeaturePyramidBackbone):
    def __init__(
        self,
        depths,
        num_layers,
        blockwise_num_heads,
        blockwise_sr_ratios,
        max_drop_path_rate,
        patch_sizes,
        strides,
        image_shape=(None, None, 3),
        hidden_dims=None,
        **kwargs,
    ):
        """A Backbone implementing the MixTransformer.

        This architecture is to be used as a backbone for the SegFormer
        architecture [SegFormer: Simple and Efficient Design for Semantic
        Segmentation with Transformers](https://arxiv.org/abs/2105.15203)

        [Based on the TensorFlow implementation from DeepVision](
        https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

        Args:
            depths: list of integers, the number of transformer encoders to
                be used per hierarchical layer in the network.
            num_layers: int. The number of hierarchical Transformer layers.
            blockwise_num_heads: list of integers, the number of heads to use
                in the attention computation for each layer.
            blockwise_sr_ratios: list of integers, the sequence reduction
                ratio to perform for each layer on the sequence before key and
                value projections. If set to > 1, a `Conv2D` layer is used to
                reduce the length of the sequence.
            max_drop_path_rate: float. The final value of the `linspace()`
                that defines the drop path rates for the `DropPath` layers of
                the `HierarchicalTransformerEncoder` layers.
            patch_sizes: list of integers, the patch_size to apply for each
                layer.
            strides: list of integers, stride to apply for each layer.
            image_shape: optional shape tuple, defaults to `(None, None, 3)`.
            hidden_dims: list of integers, the embedding dims per hierarchical
                layer, used as the levels of the feature pyramid.

        Examples:

        Using the class with a `backbone`:

        ```python
        images = np.ones(shape=(1, 96, 96, 3))
        labels = np.zeros(shape=(1, 96, 96, 1))
        backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")

        # Evaluate model
        backbone(images)

        # Train model
        backbone.compile(
            optimizer="adam",
            loss=keras.losses.BinaryCrossentropy(from_logits=False),
            metrics=["accuracy"],
        )
        backbone.fit(images, labels, epochs=3)
        ```
        """
        # Stochastic depth schedule: drop path rates increase linearly from
        # 0.0 up to `max_drop_path_rate` across every encoder in the network.
        dpr = list(np.linspace(0.0, max_drop_path_rate, sum(depths)))

        # === Layers ===
        cur = 0
        patch_embedding_layers = []
        transformer_blocks = []
        layer_norms = []
        for i in range(num_layers):
            patch_embed_layer = OverlappingPatchingAndEmbedding(
                project_dim=hidden_dims[i],
                patch_size=patch_sizes[i],
                stride=strides[i],
                name=f"patch_and_embed_{i}",
            )
            patch_embedding_layers.append(patch_embed_layer)

            # One stack of `depths[i]` encoders per hierarchical level, each
            # with its own drop path rate from the global schedule.
            transformer_block = [
                HierarchicalTransformerEncoder(
                    project_dim=hidden_dims[i],
                    num_heads=blockwise_num_heads[i],
                    sr_ratio=blockwise_sr_ratios[i],
                    drop_prob=dpr[cur + k],
                    name=f"hierarchical_encoder_{i}_{k}",
                )
                for k in range(depths[i])
            ]
            transformer_blocks.append(transformer_block)
            cur += depths[i]
            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

        # === Functional Model ===
        image_input = keras.layers.Input(shape=image_shape)
        x = image_input  # Intermediate result.
        pyramid_outputs = {}
        for i in range(num_layers):
            # Compute new height/width after the `proj`
            # call in `OverlappingPatchingAndEmbedding`.
            # NOTE(review): this divides the *static* spatial dims by the
            # stride, so it presumably requires a concrete (non-None)
            # height/width at build time — confirm against callers.
            stride = strides[i]
            new_height, new_width = (
                int(ops.shape(x)[1] / stride),
                int(ops.shape(x)[2] / stride),
            )

            x = patch_embedding_layers[i](x)
            for blk in transformer_blocks[i]:
                x = blk(x)
            x = layer_norms[i](x)
            # Restore the spatial layout so each level can feed a feature
            # pyramid consumer (e.g. the SegFormer decoder).
            x = keras.layers.Reshape(
                (new_height, new_width, -1), name=f"output_level_{i}"
            )(x)
            pyramid_outputs[f"P{i + 1}"] = x

        super().__init__(inputs=image_input, outputs=x, **kwargs)

        # === Config ===
        self.depths = depths
        self.image_shape = image_shape
        self.hidden_dims = hidden_dims
        self.pyramid_outputs = pyramid_outputs
        self.num_layers = num_layers
        self.blockwise_num_heads = blockwise_num_heads
        self.blockwise_sr_ratios = blockwise_sr_ratios
        self.max_drop_path_rate = max_drop_path_rate
        self.patch_sizes = patch_sizes
        self.strides = strides

    def get_config(self):
        """Return the constructor arguments needed to re-create this backbone."""
        config = super().get_config()
        config.update(
            {
                "depths": self.depths,
                "hidden_dims": self.hidden_dims,
                "image_shape": self.image_shape,
                "num_layers": self.num_layers,
                "blockwise_num_heads": self.blockwise_num_heads,
                "blockwise_sr_ratios": self.blockwise_sr_ratios,
                "max_drop_path_rate": self.max_drop_path_rate,
                "patch_sizes": self.patch_sizes,
                "strides": self.strides,
            }
        )
        return config