Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.

Commit ab812d1

Browse files
[DeepVision Port] SegFormer and Mix-Transformers (#1946)
* initial dump * add all basic layers, port roughly to keras core ops * updated .gitignore * segformer head and formatting * cleanup * remove tf call * remove tf * migrating to more keras ops * cleanups and fixes * fix reshaping * comments * from presets api, keras.ops -> ops * embed_dims -> embedding_dims * addressing some PR comments * docstrings, argument update * depths arg * sync * compute output shapes * segformer progress * head * softmax * remove softmax * undo compute_output_shapes() * efficientmultiheadattention -> segformermultiheadattention * docstrings * softmax output * segformer presets * updating segformer presets * segformer presets * import aliases * refactoring * pr comments * pr comments * add aliases * aliases ot init * refactor fix * import keras_cv_export * fix presets/aliases and add copyright * linter warnings * linter errors * consistency in presets * return config * fix serialization * Some cleanup + more tests * Fix DropPath layer (need to update tests + add shim for tf.keras * Finish DropPath layer * Use static shape in backbone * Formatting * Switch back to ops.shape * documentation * documentation * remove default num classes * fix docs --------- Co-authored-by: ianjjohnson <3072903+ianstenbit@users.noreply.github.com>
1 parent b038f58 commit ab812d1

22 files changed

+1855
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ __pycache__/
1616
.vscode/
1717
.devcontainer/
1818
.coverage
19+
.history

keras_cv/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676

7777
from keras_cv.backend import config # noqa: E402
7878
from keras_cv.backend import ops # noqa: E402
79+
from keras_cv.backend import random # noqa: E402
7980
from keras_cv.backend import tf_ops # noqa: E402
8081

8182

keras_cv/backend/random.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from keras_cv.backend.config import multi_backend

# Shim that routes `keras_cv.backend.random` to the random ops of whichever
# backend is active.
if multi_backend():
    # Multi-backend mode: re-export keras-core's public random API.
    from keras_core.random import *  # noqa: F403, F401
else:
    # tf.keras mode: re-export keras-core's TensorFlow-backend random
    # implementation directly (keras-core's private path — see NOTE below).
    # NOTE(review): this reaches into `keras_core.src`, a private namespace
    # that may move between keras-core releases — confirm on upgrade.
    from keras_core.src.backend.tensorflow.random import *  # noqa: F403, F401

keras_cv/layers/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from keras_cv.layers.augmenter import Augmenter
2020
from keras_cv.layers.feature_pyramid import FeaturePyramid
2121
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
22+
from keras_cv.layers.hierarchical_transformer_encoder import (
23+
HierarchicalTransformerEncoder,
24+
)
2225
from keras_cv.layers.mbconv import MBConvBlock
2326
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
2427
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
@@ -32,6 +35,9 @@
3235
CenterNetLabelEncoder,
3336
)
3437
from keras_cv.layers.object_detection_3d.voxelization import DynamicVoxelization
38+
from keras_cv.layers.overlapping_patching_embedding import (
39+
OverlappingPatchingAndEmbedding,
40+
)
3541
from keras_cv.layers.preprocessing.aug_mix import AugMix
3642
from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
3743
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
@@ -124,6 +130,9 @@
124130
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
125131
from keras_cv.layers.regularization.squeeze_excite import SqueezeAndExcite2D
126132
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth
133+
from keras_cv.layers.segformer_multihead_attention import (
134+
SegFormerMultiheadAttention,
135+
)
127136
from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling
128137
from keras_cv.layers.transformer_encoder import TransformerEncoder
129138
from keras_cv.layers.vit_layers import PatchingAndEmbedding
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
17+
from keras_cv.api_export import keras_cv_export
18+
from keras_cv.backend import keras
19+
from keras_cv.backend import ops
20+
from keras_cv.layers.regularization.drop_path import DropPath
21+
from keras_cv.layers.segformer_multihead_attention import (
22+
SegFormerMultiheadAttention,
23+
)
24+
25+
26+
@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder")
class HierarchicalTransformerEncoder(keras.layers.Layer):
    """
    Hierarchical transformer encoder block implementation as a Keras Layer.
    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
    alternative for computational efficiency, and is meant to be used
    within the SegFormer architecture.

    References:
        - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
        - [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
        - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501

    Args:
        project_dim: integer, the dimensionality of the projection of the
            encoder, and output of the `SegFormerMultiheadAttention` layer.
            Due to the residual addition the input dimensionality has to be
            equal to the output dimensionality.
        num_heads: integer, the number of heads for the
            `SegFormerMultiheadAttention` layer.
        sr_ratio: integer, the ratio to use within
            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
            layer is used to reduce the length of the sequence. Defaults to `1`.
        drop_prob: float, the probability of dropping a random
            sample using the `DropPath` layer. Defaults to `0.0`.
        layer_norm_epsilon: float, the epsilon for
            `LayerNormalization` layers. Defaults to `1e-06`

    Basic usage:

    ```
    project_dim = 1024
    num_heads = 4
    patch_size = 16

    encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
    project_dim=project_dim, patch_size=patch_size)(img_batch)

    trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim,
                                                                   num_heads=num_heads,
                                                                   sr_ratio=1)(encoded_patches)

    print(trans_encoded.shape) # (1, 3136, 1024)
    ```
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        # Kept as `drop_prop` (sic) for backward compatibility with code
        # reading this attribute; the constructor argument is `drop_prob`.
        self.drop_prop = drop_prob
        # Stored so `get_config()` can round-trip all constructor arguments.
        self.sr_ratio = sr_ratio
        self.layer_norm_epsilon = layer_norm_epsilon

        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = SegFormerMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        self.drop_path = DropPath(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = self.MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def build(self, input_shape):
        super().build(input_shape)
        # NOTE(review): H and W are computed here but never read in `call()`.
        # For a (batch, seq_len, channels) input, `input_shape[2]` is the
        # channel dimension, so `self.W` looks suspect — confirm intent
        # before relying on either attribute.
        self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
        self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

    def call(self, x):
        # Pre-norm residual attention followed by a pre-norm residual
        # MixFFN, each branch regularized with stochastic depth (DropPath).
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        # Bug fix: the previous config serialized the `mlp` sub-layer and a
        # "drop_prop" key, neither of which `__init__` accepts, so
        # `from_config()` raised a TypeError. Serialize only constructor
        # arguments (the sub-layers are rebuilt by `__init__`).
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
                "drop_prob": self.drop_prop,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    class MixFFN(keras.layers.Layer):
        """Mix-FFN block: Dense expansion, 3x3 depthwise conv for spatial
        mixing, GELU, then a Dense projection back to `channels`."""

        def __init__(self, channels, mid_channels):
            super().__init__()
            self.fc1 = keras.layers.Dense(mid_channels)
            self.dwconv = keras.layers.DepthwiseConv2D(
                kernel_size=3,
                strides=1,
                padding="same",
            )
            self.fc2 = keras.layers.Dense(channels)

        def call(self, x):
            x = self.fc1(x)
            shape = ops.shape(x)
            # Recover the spatial layout from the flattened sequence;
            # assumes seq_len is a perfect square (H == W) — true for the
            # square inputs SegFormer is used with.
            H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
            B, C = shape[0], shape[2]
            x = ops.reshape(x, (B, H, W, C))
            x = self.dwconv(x)
            x = ops.reshape(x, (B, -1, C))
            x = ops.nn.gelu(x)
            x = self.fc2(x)
            return x
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2023 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from keras_cv.api_export import keras_cv_export
16+
from keras_cv.backend import keras
17+
from keras_cv.backend import ops
18+
19+
20+
@keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding")
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
        """
        Overlapping Patching and Embedding layer. Differs from `PatchingAndEmbedding`
        in that the patch size does not affect the sequence length. It's fully derived
        from the `stride` parameter. Additionally, no positional embedding is done
        as part of the layer - only a projection using a `Conv2D` layer.

        References:
            - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
            - [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
            - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501

        Args:
            project_dim: integer, the dimensionality of the projection.
                Defaults to `32`.
            patch_size: integer, the size of the patches to encode.
                Defaults to `7`.
            stride: integer, the stride to use for the patching before
                projection. Defaults to `4`.

        Basic usage:

        ```
        project_dim = 1024
        patch_size = 16

        encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
        project_dim=project_dim, patch_size=patch_size)(img_batch)

        print(encoded_patches.shape) # (1, 3136, 1024)
        ```
        """
        super().__init__(**kwargs)

        self.project_dim = project_dim
        self.patch_size = patch_size
        self.stride = stride

        # "Overlapping" patching: patch_size > stride makes adjacent patches
        # share pixels; `same` padding keeps the output grid at H/stride.
        self.proj = keras.layers.Conv2D(
            filters=project_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="same",
        )
        self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        x = self.proj(x)
        # B, H, W, C -> flatten the spatial grid into a sequence of patches.
        # Uses the static shape; assumes H and W are known at build time —
        # the batch dim stays dynamic via -1.
        shape = x.shape
        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
        x = self.norm(x)
        return x

    def get_config(self):
        # Serialize exactly the constructor arguments so the layer
        # round-trips through `from_config()`.
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "patch_size": self.patch_size,
                "stride": self.stride,
            }
        )
        return config

keras_cv/layers/regularization/drop_path.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from tensorflow import keras
16-
1715
from keras_cv.api_export import keras_cv_export
16+
from keras_cv.backend import keras
17+
from keras_cv.backend import ops
18+
from keras_cv.backend import random
1819

1920

2021
@keras_cv_export("keras_cv.layers.DropPath")
21-
class DropPath(keras.__internal__.layers.BaseRandomLayer):
22+
class DropPath(keras.layers.Layer):
2223
"""
2324
Implements the DropPath layer. DropPath randomly drops samples during
2425
training with a probability of `rate`. Note that this layer drops individual
@@ -47,20 +48,21 @@ class DropPath(keras.__internal__.layers.BaseRandomLayer):
4748
""" # noqa: E501
4849

4950
def __init__(self, rate=0.5, seed=None, **kwargs):
50-
super().__init__(seed=seed, **kwargs)
51+
super().__init__(**kwargs)
5152
self.rate = rate
5253
self.seed = seed
5354

5455
def call(self, x, training=None):
5556
if self.rate == 0.0 or not training:
5657
return x
5758
else:
58-
keep_prob = 1 - self.rate
59-
drop_map_shape = (x.shape[0],) + (1,) * (len(x.shape) - 1)
60-
drop_map = keras.backend.random_bernoulli(
61-
drop_map_shape, p=keep_prob, seed=self.seed
59+
batch_size = x.shape[0] or ops.shape(x)[0]
60+
drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
61+
drop_map = ops.cast(
62+
random.uniform(drop_map_shape, seed=self.seed) > self.rate,
63+
x.dtype,
6264
)
63-
x = x / keep_prob
65+
x = x / (1.0 - self.rate)
6466
x = x * drop_map
6567
return x
6668

keras_cv/layers/regularization/drop_path_test.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
16+
import pytest
1517
import tensorflow as tf
1618

1719
from keras_cv.layers import DropPath
@@ -23,23 +25,23 @@ class DropPathTest(TestCase):
2325

2426
def test_input_unchanged_in_eval_mode(self):
2527
layer = DropPath(rate=0.5, seed=42)
26-
inputs = tf.random.uniform(self.FEATURE_SHAPE)
28+
inputs = np.random.uniform(size=self.FEATURE_SHAPE)
2729

2830
outputs = layer(inputs, training=False)
2931

3032
self.assertAllClose(inputs, outputs)
3133

3234
def test_input_unchanged_with_rate_equal_to_zero(self):
3335
layer = DropPath(rate=0, seed=42)
34-
inputs = tf.random.uniform(self.FEATURE_SHAPE)
36+
inputs = np.random.uniform(size=self.FEATURE_SHAPE)
3537

3638
outputs = layer(inputs, training=True)
3739

3840
self.assertAllClose(inputs, outputs)
3941

4042
def test_input_gets_partially_zeroed_out_in_train_mode(self):
4143
layer = DropPath(rate=0.2, seed=42)
42-
inputs = tf.random.uniform(self.FEATURE_SHAPE)
44+
inputs = np.random.uniform(size=self.FEATURE_SHAPE)
4345

4446
outputs = layer(inputs, training=True)
4547

@@ -48,9 +50,11 @@ def test_input_gets_partially_zeroed_out_in_train_mode(self):
4850

4951
self.assertGreaterEqual(non_zeros_inputs, non_zeros_outputs)
5052

53+
# Because randomness is inconsistent across backends, we just test with 1.
54+
@pytest.mark.tf_keras_only
5155
def test_strict_input_gets_partially_zeroed_out_in_train_mode(self):
52-
layer = DropPath(rate=0.5, seed=42)
53-
inputs = tf.random.uniform(self.FEATURE_SHAPE)
56+
layer = DropPath(rate=0.5, seed=10)
57+
inputs = np.random.uniform(size=self.FEATURE_SHAPE)
5458

5559
total_non_zero_inputs = 0
5660
total_non_zero_outputs = 0
@@ -66,6 +70,6 @@ def test_strict_input_gets_partially_zeroed_out_in_train_mode(self):
6670

6771
self.assertAllInRange(
6872
total_non_zero_outputs,
69-
int(0.49 * tf.cast(total_non_zero_inputs, tf.float32)),
70-
int(0.51 * tf.cast(total_non_zero_inputs, tf.float32)),
73+
int(0.40 * tf.cast(total_non_zero_inputs, tf.float32)),
74+
int(0.60 * tf.cast(total_non_zero_inputs, tf.float32)),
7175
)

0 commit comments

Comments
 (0)