This repository was archived by the owner on Mar 10, 2026. It is now read-only.
[DeepVision Port] SegFormer and Mix-Transformers #1946
Merged
Changes from 46 commits
Commits (57)
dc41892
initial dump
DavidLandup0 e5677e6
add all basic layers, port roughly to keras core ops
DavidLandup0 7bd1056
updated .gitignore
DavidLandup0 03470df
segformer head and formatting
DavidLandup0 cb1c702
cleanup
DavidLandup0 22f8fdf
remove tf call
DavidLandup0 5c9803a
remove tf
DavidLandup0 314dc6b
migrating to more keras ops
DavidLandup0 7a0151b
cleanups and fixes
DavidLandup0 44f01af
fix reshaping
DavidLandup0 eb5b5ae
comments
DavidLandup0 ea0239f
from presets api, keras.ops -> ops
DavidLandup0 b6128a5
embed_dims -> embedding_dims
DavidLandup0 8322109
addressing some PR comments
DavidLandup0 75bb4a2
docstrings, argument update
DavidLandup0 97daf7c
depths arg
DavidLandup0 5f9dc0c
sync
DavidLandup0 efbbd49
compute output shapes
DavidLandup0 d3b43c6
segformer progress
DavidLandup0 dab4e74
head
DavidLandup0 1dba059
softmax
DavidLandup0 bdc3687
remove softmax
DavidLandup0 ddfa315
undo compute_output_shapes()
DavidLandup0 5a091b6
efficientmultiheadattention -> segformermultiheadattention
DavidLandup0 4e9df16
docstrings
DavidLandup0 278875c
softmax output
DavidLandup0 884c376
Merge branch 'master' into segformer_tf
DavidLandup0 6618a65
segformer presets
DavidLandup0 e1fbdb0
Merge branch 'segformer_tf' of https://github.com/DavidLandup0/keras-…
DavidLandup0 00ecd92
updating segformer presets
DavidLandup0 97d9d4a
segformer presets
DavidLandup0 c10963f
import aliases
DavidLandup0 f882b3e
Merge branch 'master' into segformer_tf
DavidLandup0 ab10136
refactoring
DavidLandup0 094189e
pr comments
DavidLandup0 a4df0a6
pr comments
DavidLandup0 e22a15e
add aliases
DavidLandup0 5d63d18
aliases ot init
DavidLandup0 03a177f
refactor fix
DavidLandup0 d1cdd5d
import keras_cv_export
DavidLandup0 ff32d63
fix presets/aliases and add copyright
DavidLandup0 5f3fc22
linter warnings
DavidLandup0 c6b454f
linter errors
DavidLandup0 5ac7f77
consistency in presets
DavidLandup0 b2a76ce
return config
DavidLandup0 0ad5879
fix serialization
DavidLandup0 eea5e3c
Some cleanup + more tests
ianstenbit 8e62cf6
Fix DropPath layer (need to update tests + add shim for tf.keras
ianstenbit b9efeb1
Finish DropPath layer
ianstenbit bd5a99f
Use static shape in backbone
ianstenbit 3d29b0a
Formatting
ianstenbit 4e2c4e8
Switch back to ops.shape
ianstenbit b32e0cf
documentation
DavidLandup0 743a3bb
documentation
DavidLandup0 c640fc9
remove default num classes
DavidLandup0 f1b5ffa
fix docs
DavidLandup0 e32704b
Merge branch 'master' into segformer_tf
ianstenbit
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,3 +16,4 @@ __pycache__/ | |
| .vscode/ | ||
| .devcontainer/ | ||
| .coverage | ||
| .history | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| # Copyright 2023 The KerasCV Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from keras_cv.api_export import keras_cv_export | ||
| from keras_cv.backend import keras | ||
| from keras_cv.backend import ops | ||
| from keras_cv.layers.regularization.drop_path import DropPath | ||
| from keras_cv.layers.segformer_multihead_attention import ( | ||
| SegFormerMultiheadAttention, | ||
| ) | ||
|
|
||
|
|
||
| @keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder") | ||
| class HierarchicalTransformerEncoder(keras.layers.Layer): | ||
| """ | ||
| Hierarchical transformer encoder block implementation as a Keras Layer. | ||
| The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention` | ||
| alternative for computational efficiency, and is meant to be used | ||
| within the SegFormer architecture. | ||
|
|
||
| References: | ||
| - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 | ||
| - [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 | ||
| - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 | ||
|
|
||
| Args: | ||
| project_dim: the dimensionality of the projection of the encoder, and | ||
| output of the `SegFormerMultiheadAttention` layer. Due to the | ||
| residual addition the input dimensionality has to be equal to | ||
| the output dimensionality. | ||
| num_heads: the number of heads for the `SegFormerMultiheadAttention` | ||
| layer | ||
| drop_prob: default 0.0, the probability of dropping a random sample | ||
| using the `DropPath` layer. | ||
| layer_norm_epsilon: default 1e-06, the epsilon for `LayerNormalization` | ||
| layers | ||
| sr_ratio: default 1, the ratio to use within `SegFormerMultiheadAttention`. # noqa: E501 | ||
| If set to > 1, a `Conv2D` layer is used to reduce the length of | ||
| the sequence. | ||
|
|
||
| Basic usage: | ||
|
|
||
| ``` | ||
| project_dim = 1024 | ||
| num_heads = 4 | ||
| patch_size = 16 | ||
|
|
||
| encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding( | ||
| project_dim=project_dim, patch_size=patch_size)(img_batch) | ||
|
|
||
| trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim, | ||
| num_heads=num_heads, | ||
| sr_ratio=1)(encoded_patches) | ||
|
|
||
| print(trans_encoded.shape) # (1, 3136, 1024) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| project_dim, | ||
| num_heads, | ||
| sr_ratio=1, | ||
| drop_prob=0.0, | ||
| layer_norm_epsilon=1e-6, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.project_dim = project_dim | ||
| self.num_heads = num_heads | ||
| self.drop_prop = drop_prob | ||
|
|
||
| self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) | ||
| self.attn = SegFormerMultiheadAttention( | ||
| project_dim, num_heads, sr_ratio | ||
| ) | ||
| self.drop_path = DropPath(drop_prob) | ||
| self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon) | ||
| self.mlp = self.MixFFN( | ||
| channels=project_dim, | ||
| mid_channels=int(project_dim * 4), | ||
| ) | ||
|
|
||
| def build(self, input_shape): | ||
| super().build(input_shape) | ||
| # Both spatial dims are recovered from the sequence length (dim 1); | ||
| # dim 2 holds channels, not a spatial extent. | ||
| self.H = ops.sqrt(ops.cast(input_shape[1], "float32")) | ||
| self.W = ops.sqrt(ops.cast(input_shape[1], "float32")) | ||
|
|
||
| def call(self, x): | ||
| x = x + self.drop_path(self.attn(self.norm1(x))) | ||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | ||
| return x | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "mlp": keras.saving.serialize_keras_object(self.mlp), | ||
| "project_dim": self.project_dim, | ||
| "num_heads": self.num_heads, | ||
| "drop_prop": self.drop_prop, | ||
| } | ||
| ) | ||
| return config | ||
|
|
||
| class MixFFN(keras.layers.Layer): | ||
| def __init__(self, channels, mid_channels): | ||
| super().__init__() | ||
| self.fc1 = keras.layers.Dense(mid_channels) | ||
| self.dwconv = keras.layers.DepthwiseConv2D( | ||
| kernel_size=3, | ||
| strides=1, | ||
| padding="same", | ||
| ) | ||
| self.fc2 = keras.layers.Dense(channels) | ||
|
|
||
| def call(self, x): | ||
| x = self.fc1(x) | ||
| shape = ops.shape(x) | ||
| B, C = shape[0], shape[-1] | ||
| # Cast the recovered side length back to int: `ops.reshape` | ||
| # rejects float dimensions. | ||
| H = W = ops.cast( | ||
| ops.sqrt(ops.cast(shape[1], "float32")), "int32" | ||
| ) | ||
| x = ops.reshape(x, (B, H, W, C)) | ||
| x = self.dwconv(x) | ||
| x = ops.reshape(x, (B, -1, C)) | ||
| x = ops.nn.gelu(x) | ||
| x = self.fc2(x) | ||
| return x | ||
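The MixFFN `call` above round-trips tokens through a spatial layout so the depthwise convolution can see a 2D neighborhood. A minimal numpy sketch of that shape round-trip (hypothetical helper names; the real layer does this inline with `ops.reshape`), assuming the sequence length is a perfect square as the layer requires:

```python
import numpy as np

def tokens_to_map(x):
    # (B, N, C) -> (B, H, W, C); N must be a perfect square, H == W == sqrt(N)
    b, n, c = x.shape
    h = w = int(np.sqrt(n))
    assert h * w == n, "sequence length must be a perfect square"
    return x.reshape(b, h, w, c)

def map_to_tokens(x):
    # (B, H, W, C) -> (B, N, C), inverse of tokens_to_map
    b, h, w, c = x.shape
    return x.reshape(b, h * w, c)

x = np.zeros((1, 3136, 64))           # 3136 tokens == a 56x56 grid
spatial = tokens_to_map(x)
print(spatial.shape)                  # (1, 56, 56, 64)
print(map_to_tokens(spatial).shape)   # (1, 3136, 64)
```

This is also why `build` derives both `H` and `W` from the sequence dimension: the token count alone determines the (square) spatial grid.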
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # Copyright 2023 The KerasCV Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from keras_cv.api_export import keras_cv_export | ||
| from keras_cv.backend import keras | ||
| from keras_cv.backend import ops | ||
|
|
||
|
|
||
| @keras_cv_export("keras_cv.layers.OverlappingPatchingAndEmbedding") | ||
| class OverlappingPatchingAndEmbedding(keras.layers.Layer): | ||
| def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): | ||
| """ | ||
| Overlapping Patching and Embedding layer. Differs from `PatchingAndEmbedding` | ||
| in that the patch size does not affect the sequence length. It's fully derived | ||
| from the `stride` parameter. Additionally, no positional embedding is done | ||
| as part of the layer - only a projection using a `Conv2D` layer. | ||
|
|
||
| References: | ||
| - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 | ||
| - [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 | ||
| - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501 | ||
|
|
||
| Args: | ||
| project_dim: default 32, the dimensionality of the projection | ||
| produced by the embedding `Conv2D` layer | ||
| patch_size: default 7, the kernel size of the `Conv2D` used to | ||
| extract (overlapping) patches | ||
| stride: default 4, the stride of the `Conv2D`; the output | ||
| sequence length is derived from it alone | ||
|
|
||
| Basic usage: | ||
|
|
||
| ``` | ||
| project_dim = 1024 | ||
| patch_size = 16 | ||
|
|
||
| encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding( | ||
| project_dim=project_dim, patch_size=patch_size)(img_batch) | ||
|
|
||
| print(encoded_patches.shape) # (1, 3136, 1024) | ||
| ``` | ||
| """ | ||
| super().__init__(**kwargs) | ||
|
|
||
| self.project_dim = project_dim | ||
| self.patch_size = patch_size | ||
| self.stride = stride | ||
|
|
||
| self.proj = keras.layers.Conv2D( | ||
| filters=project_dim, | ||
| kernel_size=patch_size, | ||
| strides=stride, | ||
| padding="same", | ||
| ) | ||
| self.norm = keras.layers.LayerNormalization() | ||
|
|
||
| def call(self, x): | ||
| x = self.proj(x) | ||
| # B, H, W, C | ||
| shape = x.shape | ||
| x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) | ||
| x = self.norm(x) | ||
| return x | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "project_dim": self.project_dim, | ||
| "patch_size": self.patch_size, | ||
| "stride": self.stride, | ||
| } | ||
| ) | ||
| return config |
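The docstring's claim that the sequence length is "fully derived from the `stride` parameter" follows from the `Conv2D` using `padding="same"`: each spatial side becomes `ceil(side / stride)` regardless of kernel size. A small sketch (helper name is hypothetical), assuming a square input:

```python
import math

def sequence_length(size, stride):
    # "same"-padded Conv2D output side is ceil(size / stride); the token
    # count is the product of the two (equal) output sides.
    side = math.ceil(size / stride)
    return side * side

print(sequence_length(224, 4))   # 3136 -- matches the docstring example
print(sequence_length(224, 2))   # 12544
```

With a 224x224 input and `stride=4`, this reproduces the `(1, 3136, project_dim)` shape shown in the usage example above.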
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Copyright 2023 The KerasCV Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from keras_cv.api_export import keras_cv_export | ||
| from keras_cv.backend import keras | ||
| from keras_cv.backend import ops | ||
|
|
||
|
|
||
| @keras_cv_export("keras_cv.layers.SegFormerMultiheadAttention") | ||
| class SegFormerMultiheadAttention(keras.layers.Layer): | ||
| def __init__(self, project_dim, num_heads, sr_ratio): | ||
| """ | ||
| Efficient MultiHeadAttention implementation as a Keras layer. | ||
| A huge bottleneck in scaling transformers is the self-attention layer | ||
| with an O(n^2) complexity. | ||
|
|
||
| SegFormerMultiheadAttention performs a sequence reduction (SR) operation | ||
| with a given ratio, to reduce the sequence length before performing key and value projections, | ||
| reducing the O(n^2) complexity to O(n^2/R) where R is the sequence reduction ratio. | ||
|
|
||
| References: | ||
| - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501 | ||
| - [NVlabs' official implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501 | ||
| - [@sithu31296's reimplementation](https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py) # noqa: E501 | ||
| - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/efficient_attention.py) # noqa: E501 | ||
|
|
||
| Args: | ||
| project_dim: the dimensionality of the projection of the `SegFormerMultiheadAttention` layer. | ||
| num_heads: the number of heads to use in the attention computation. | ||
| sr_ratio: the sequence reduction ratio to perform on the sequence before key and value projections. | ||
|
|
||
| Basic usage: | ||
|
|
||
| ``` | ||
| tensor = tf.random.uniform([1, 196, 32]) | ||
| output = keras_cv.layers.SegFormerMultiheadAttention(project_dim=32, | ||
| num_heads=2, | ||
| sr_ratio=4)(tensor) | ||
| print(output.shape) # (1, 196, 32) | ||
| ``` | ||
| """ | ||
| super().__init__() | ||
| self.num_heads = num_heads | ||
| self.sr_ratio = sr_ratio | ||
| self.scale = (project_dim // num_heads) ** -0.5 | ||
| self.q = keras.layers.Dense(project_dim) | ||
| self.k = keras.layers.Dense(project_dim) | ||
| self.v = keras.layers.Dense(project_dim) | ||
| self.proj = keras.layers.Dense(project_dim) | ||
|
|
||
| if sr_ratio > 1: | ||
| self.sr = keras.layers.Conv2D( | ||
| filters=project_dim, | ||
| kernel_size=sr_ratio, | ||
| strides=sr_ratio, | ||
| padding="same", | ||
| ) | ||
| self.norm = keras.layers.LayerNormalization() | ||
|
|
||
| def call(self, x): | ||
| input_shape = ops.shape(x) | ||
| # Recover integer spatial dims from the sequence length; float | ||
| # dims would be rejected by the reshapes below. | ||
| H = W = ops.cast( | ||
| ops.sqrt(ops.cast(input_shape[1], "float32")), "int32" | ||
| ) | ||
| B, C = input_shape[0], input_shape[2] | ||
|
|
||
| q = self.q(x) | ||
| q = ops.reshape( | ||
| q, | ||
| ( | ||
| input_shape[0], | ||
| input_shape[1], | ||
| self.num_heads, | ||
| input_shape[2] // self.num_heads, | ||
| ), | ||
| ) | ||
| q = ops.transpose(q, [0, 2, 1, 3]) | ||
|
|
||
| if self.sr_ratio > 1: | ||
| # Channels-last layout: fold the sequence back into a | ||
| # (B, H, W, C) map, reduce it, then flatten it again. | ||
| x = ops.reshape(x, (B, H, W, C)) | ||
| x = self.sr(x) | ||
| x = ops.reshape(x, (B, -1, C)) | ||
| x = self.norm(x) | ||
|
|
||
| k = self.k(x) | ||
| v = self.v(x) | ||
|
|
||
| k = ops.transpose( | ||
| ops.reshape( | ||
| k, | ||
| [B, -1, self.num_heads, C // self.num_heads], | ||
| ), | ||
| [0, 2, 1, 3], | ||
| ) | ||
|
|
||
| v = ops.transpose( | ||
| ops.reshape( | ||
| v, | ||
| [B, -1, self.num_heads, C // self.num_heads], | ||
| ), | ||
| [0, 2, 1, 3], | ||
| ) | ||
|
|
||
| attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale | ||
| attn = ops.nn.softmax(attn, axis=-1) | ||
|
|
||
| attn = attn @ v | ||
| attn = ops.reshape( | ||
| ops.transpose(attn, [0, 2, 1, 3]), | ||
| [input_shape[0], input_shape[1], input_shape[2]], | ||
| ) | ||
| x = self.proj(attn) | ||
| return x |
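The docstring's complexity claim can be checked with simple arithmetic: the stride-`sr_ratio`, "same"-padded `Conv2D` shrinks each spatial side by the ratio, so the key/value sequence shrinks by roughly `sr_ratio**2`, and the attention score matrix goes from `N x N` to `N x (N / sr_ratio**2)`. A back-of-the-envelope sketch (hypothetical helper, square token grid assumed):

```python
import math

def attention_matrix_sizes(n, sr):
    # n tokens form a sqrt(n) x sqrt(n) grid; SR reduces each side by `sr`
    # ("same" padding -> ceil), so keys/values shrink by roughly sr**2.
    side = int(math.isqrt(n))
    reduced = math.ceil(side / sr) ** 2   # key/value length after reduction
    return n * n, n * reduced             # full vs. reduced score-matrix size

full, reduced = attention_matrix_sizes(3136, 4)
print(full)     # 9834496 scores without reduction
print(reduced)  # 614656 scores with sr_ratio=4, a 16x (= 4**2) saving
```

For the 3136-token grids used elsewhere in this PR, `sr_ratio=4` cuts the score matrix by a factor of 16, which is the `O(n^2 / R)` saving the docstring describes (with `R` being the sequence-length reduction factor).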