Skip to content
This repository was archived by the owner on Mar 10, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
dc41892
initial dump
DavidLandup0 Jul 13, 2023
e5677e6
add all basic layers, port roughly to keras core ops
DavidLandup0 Jul 17, 2023
7bd1056
updated .gitignore
DavidLandup0 Jul 17, 2023
03470df
segformer head and formatting
DavidLandup0 Jul 17, 2023
cb1c702
cleanup
DavidLandup0 Jul 17, 2023
22f8fdf
remove tf call
DavidLandup0 Jul 17, 2023
5c9803a
remove tf
DavidLandup0 Jul 17, 2023
314dc6b
migrating to more keras ops
DavidLandup0 Jul 17, 2023
7a0151b
cleanups and fixes
DavidLandup0 Jul 23, 2023
44f01af
fix reshaping
DavidLandup0 Jul 23, 2023
eb5b5ae
comments
DavidLandup0 Jul 23, 2023
ea0239f
from presets api, keras.ops -> ops
DavidLandup0 Jul 23, 2023
b6128a5
embed_dims -> embedding_dims
DavidLandup0 Jul 23, 2023
8322109
addressing some PR comments
DavidLandup0 Jul 24, 2023
75bb4a2
docstrings, argument update
DavidLandup0 Jul 24, 2023
97daf7c
depths arg
DavidLandup0 Jul 24, 2023
5f9dc0c
sync
DavidLandup0 Jul 24, 2023
efbbd49
compute output shapes
DavidLandup0 Jul 26, 2023
d3b43c6
segformer progress
DavidLandup0 Jul 26, 2023
dab4e74
head
DavidLandup0 Jul 27, 2023
1dba059
softmax
DavidLandup0 Jul 27, 2023
bdc3687
remove softmax
DavidLandup0 Jul 28, 2023
ddfa315
undo compute_output_shapes()
DavidLandup0 Jul 28, 2023
5a091b6
efficientmultiheadattention -> segformermultiheadattention
DavidLandup0 Jul 30, 2023
4e9df16
docstrings
DavidLandup0 Jul 30, 2023
278875c
softmax output
DavidLandup0 Jul 30, 2023
884c376
Merge branch 'master' into segformer_tf
DavidLandup0 Jul 30, 2023
6618a65
segformer presets
DavidLandup0 Aug 1, 2023
e1fbdb0
Merge branch 'segformer_tf' of https://github.com/DavidLandup0/keras-…
DavidLandup0 Aug 1, 2023
00ecd92
updating segformer presets
DavidLandup0 Aug 1, 2023
97d9d4a
segformer presets
DavidLandup0 Aug 18, 2023
c10963f
import aliases
DavidLandup0 Aug 18, 2023
f882b3e
Merge branch 'master' into segformer_tf
DavidLandup0 Aug 18, 2023
ab10136
refactoring
DavidLandup0 Aug 18, 2023
094189e
pr comments
DavidLandup0 Aug 18, 2023
a4df0a6
pr comments
DavidLandup0 Aug 18, 2023
e22a15e
add aliases
DavidLandup0 Aug 18, 2023
5d63d18
aliases to init
DavidLandup0 Aug 18, 2023
03a177f
refactor fix
DavidLandup0 Aug 18, 2023
d1cdd5d
import keras_cv_export
DavidLandup0 Aug 18, 2023
ff32d63
fix presets/aliases and add copyright
DavidLandup0 Aug 19, 2023
5f3fc22
linter warnings
DavidLandup0 Aug 19, 2023
c6b454f
linter errors
DavidLandup0 Aug 19, 2023
5ac7f77
consistency in presets
DavidLandup0 Aug 19, 2023
b2a76ce
return config
DavidLandup0 Aug 19, 2023
0ad5879
fix serialization
DavidLandup0 Aug 19, 2023
eea5e3c
Some cleanup + more tests
ianstenbit Aug 21, 2023
8e62cf6
Fix DropPath layer (need to update tests + add shim for tf.keras)
ianstenbit Aug 21, 2023
b9efeb1
Finish DropPath layer
ianstenbit Aug 21, 2023
bd5a99f
Use static shape in backbone
ianstenbit Aug 21, 2023
3d29b0a
Formatting
ianstenbit Aug 21, 2023
4e2c4e8
Switch back to ops.shape
ianstenbit Aug 21, 2023
b32e0cf
documentation
DavidLandup0 Aug 23, 2023
743a3bb
documentation
DavidLandup0 Aug 23, 2023
c640fc9
remove default num classes
DavidLandup0 Aug 23, 2023
f1b5ffa
fix docs
DavidLandup0 Aug 23, 2023
e32704b
Merge branch 'master' into segformer_tf
ianstenbit Aug 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ __pycache__/
.vscode/
.devcontainer/
.coverage
.history
9 changes: 9 additions & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@
from tensorflow.keras.layers import RandomHeight
from tensorflow.keras.layers import RandomWidth

from keras_cv.layers.efficient_multihead_attention import (
EfficientMultiheadAttention,
)
from keras_cv.layers.feature_pyramid import FeaturePyramid
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.hierarchical_transformer_encoder import (
HierarchicalTransformerEncoder,
)
from keras_cv.layers.mbconv import MBConvBlock
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
Expand All @@ -31,6 +37,9 @@
CenterNetLabelEncoder,
)
from keras_cv.layers.object_detection_3d.voxelization import DynamicVoxelization
from keras_cv.layers.overlapping_patching_embedding import (
OverlappingPatchingAndEmbedding,
)
from keras_cv.layers.preprocessing.aug_mix import AugMix
from keras_cv.layers.preprocessing.auto_contrast import AutoContrast
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
Expand Down
87 changes: 87 additions & 0 deletions keras_cv/layers/efficient_multihead_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from keras_cv.backend import keras

"""
Based on: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py
"""


@keras.saving.register_keras_serializable(package="keras_cv")
class EfficientMultiheadAttention(keras.layers.Layer):
    """Efficient multi-head self-attention with optional spatial reduction.

    Used by the SegFormer/MiT architecture. When `sr_ratio > 1`, the keys
    and values are computed from a spatially downsampled copy of the input
    (via a strided Conv2D), which reduces the cost of the attention matrix
    from O(N^2) to O(N^2 / sr_ratio^2) for a sequence of N = H*W tokens.

    Args:
        project_dim: int, the projection dimensionality of the queries,
            keys, and values. Must be divisible by `num_heads`.
        num_heads: int, the number of attention heads.
        sr_ratio: int, the spatial reduction ratio applied to keys/values.
    """

    def __init__(self, project_dim, num_heads, sr_ratio):
        super().__init__()
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.scale = (project_dim // num_heads) ** -0.5
        self.q = keras.layers.Dense(project_dim)
        self.k = keras.layers.Dense(project_dim)
        self.v = keras.layers.Dense(project_dim)
        self.proj = keras.layers.Dense(project_dim)

        if sr_ratio > 1:
            # Strided conv downsamples the (H, W) grid before K/V are
            # computed, shrinking the attended sequence by sr_ratio**2.
            self.sr = keras.layers.Conv2D(
                filters=project_dim,
                kernel_size=sr_ratio,
                strides=sr_ratio,
                padding="same",
            )
            self.norm = keras.layers.LayerNormalization()

    def call(self, x, H, W):
        """Applies attention over a flattened spatial sequence.

        Args:
            x: tensor of shape `(batch, H * W, channels)` (channels-last).
            H: int, height of the un-flattened spatial grid.
            W: int, width of the un-flattened spatial grid.

        Returns:
            Tensor of shape `(batch, H * W, project_dim)`.
        """
        input_shape = x.shape  # (B, N, C)
        head_dim = input_shape[2] // self.num_heads

        # Queries: split channels into heads, then move heads forward:
        # (B, N, C) -> (B, N, heads, head_dim) -> (B, heads, N, head_dim).
        # Fixed: original called `keras.ops.reshape.reshape` and used the
        # tensor method `q.transpose(...)`, neither of which exists.
        q = self.q(x)
        q = keras.ops.reshape(
            q,
            (input_shape[0], input_shape[1], self.num_heads, head_dim),
        )
        q = keras.ops.transpose(q, [0, 2, 1, 3])

        if self.sr_ratio > 1:
            # Restore the spatial grid, downsample it, then flatten back.
            # Fixed: the original transposed to channels-first before the
            # reshape (a leftover from the PyTorch port), which scrambles
            # a channels-last tensor.
            x = keras.ops.reshape(
                x, (input_shape[0], H, W, input_shape[2])
            )
            x = self.sr(x)
            x = keras.ops.reshape(x, [input_shape[0], -1, input_shape[2]])
            x = self.norm(x)

        k = self.k(x)
        v = self.v(x)

        # Keys/values are 3-D here; reshape into heads first, THEN apply
        # the 4-axis transpose. (Original applied a 4-axis transpose to the
        # 3-D tensors, which is invalid.)
        k = keras.ops.transpose(
            keras.ops.reshape(
                k, [input_shape[0], -1, self.num_heads, head_dim]
            ),
            [0, 2, 1, 3],
        )
        v = keras.ops.transpose(
            keras.ops.reshape(
                v, [input_shape[0], -1, self.num_heads, head_dim]
            ),
            [0, 2, 1, 3],
        )

        # Scaled dot-product attention. Fixed: the original attended
        # against `x` rather than the projected keys `k`.
        attn = (q @ keras.ops.transpose(k, [0, 1, 3, 2])) * self.scale
        # Fixed: `keras.nn.ops.softmax` is not a valid namespace.
        attn = keras.ops.softmax(attn, axis=-1)

        attn = attn @ v
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        attn = keras.ops.reshape(
            keras.ops.transpose(attn, [0, 2, 1, 3]),
            [input_shape[0], input_shape[1], input_shape[2]],
        )
        return self.proj(attn)

    def get_config(self):
        # Required for round-trip (de)serialization: the layer is
        # registered with `register_keras_serializable`, so its
        # constructor arguments must appear in the config.
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
            }
        )
        return config
57 changes: 57 additions & 0 deletions keras_cv/layers/hierarchical_transformer_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from keras_cv.backend import keras
from keras_cv.layers.efficient_multihead_attention import (
EfficientMultiheadAttention,
)
from keras_cv.layers.regularization.stochastic_depth import StochasticDepth


@keras.saving.register_keras_serializable(package="keras_cv")
class HierarchicalTransformerEncoder(keras.layers.Layer):
    """SegFormer/MiT transformer encoder block.

    Pre-norm transformer block: efficient multi-head attention followed by
    a Mix-FFN (an MLP with a depthwise 3x3 conv between the two dense
    layers), each wrapped in a residual connection with stochastic depth.

    Args:
        project_dim: int, the projection dimensionality of the block.
        num_heads: int, number of attention heads.
        sr_ratio: int, spatial-reduction ratio for the attention layer.
            Defaults to 1 (no reduction).
        drop_prob: float, stochastic-depth drop probability for the
            residual branches. Defaults to 0.0.
        layer_norm_epsilon: float, epsilon for the LayerNormalization
            layers. Defaults to 1e-6.
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.drop_prob = drop_prob
        self.layer_norm_epsilon = layer_norm_epsilon
        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = EfficientMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        self.drop_path = StochasticDepth(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = self.__MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def call(self, x, H, W):
        """Runs one encoder block over a flattened `(B, H*W, C)` sequence."""
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x

    def get_config(self):
        # Constructor arguments must be serialized for the registered
        # layer to be reconstructable from a saved model.
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
                "drop_prob": self.drop_prob,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    class __MixFFN(keras.layers.Layer):
        """Mix-FFN: Dense -> depthwise 3x3 conv -> GELU -> Dense."""

        def __init__(self, channels, mid_channels):
            super().__init__()
            self.fc1 = keras.layers.Dense(mid_channels)
            self.dwconv = keras.layers.DepthwiseConv2D(
                kernel_size=3,
                strides=1,
                padding="same",
            )
            self.fc2 = keras.layers.Dense(channels)

        def call(self, x, H, W):
            x = self.fc1(x)
            # (B, N, mid_channels); restore the spatial grid for the
            # depthwise conv, then flatten back.
            input_shape = x.shape

            x = keras.ops.reshape(x, (input_shape[0], H, W, input_shape[-1]))
            x = self.dwconv(x)
            x = keras.ops.reshape(x, (input_shape[0], -1, input_shape[-1]))
            # Fixed: `keras.nn.ops.gelu` is not a valid namespace; the op
            # lives at `keras.ops.gelu`.
            x = keras.ops.gelu(x)
            x = self.fc2(x)
            return x
22 changes: 22 additions & 0 deletions keras_cv/layers/overlapping_patching_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from keras_cv.backend import keras


@keras.saving.register_keras_serializable(package="keras_cv")
class OverlappingPatchingAndEmbedding(keras.layers.Layer):
    """Overlapping patch embedding for SegFormer/MiT.

    Projects an image (or feature map) into patch embeddings with a strided
    convolution whose kernel is larger than its stride, so neighboring
    patches overlap, then flattens the spatial grid into a token sequence
    and layer-normalizes it.

    Args:
        out_channels: int, embedding dimensionality. Defaults to 32.
        patch_size: int, convolution kernel size. Defaults to 7.
        stride: int, convolution stride. Defaults to 4.
    """

    def __init__(self, out_channels=32, patch_size=7, stride=4, **kwargs):
        super().__init__(**kwargs)
        self.out_channels = out_channels
        self.patch_size = patch_size
        self.stride = stride
        self.proj = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=patch_size,
            strides=stride,
            padding="same",
        )
        self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        """Returns `(tokens, H, W)` where tokens has shape `(B, H*W, C)`."""
        x = self.proj(x)
        # B, H, W, C
        shape = x.shape
        x = keras.ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
        x = self.norm(x)
        # H and W are returned so downstream blocks can un-flatten.
        return x, shape[1], shape[2]

    def get_config(self):
        # Fixed: the layer is registered as serializable but previously
        # had no `get_config`, so its constructor arguments were lost on
        # save/load round-trips.
        config = super().get_config()
        config.update(
            {
                "out_channels": self.out_channels,
                "patch_size": self.patch_size,
                "stride": self.stride,
            }
        )
        return config
4 changes: 4 additions & 0 deletions keras_cv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
from keras_cv.models.backbones.efficientnet_v2.efficientnet_v2_aliases import (
EfficientNetV2SBackbone,
)
from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone import (
MiTBackbone,
)
from keras_cv.models.backbones.mobilenet_v3.mobilenet_v3_aliases import (
MobileNetV3LargeBackbone,
)
Expand Down Expand Up @@ -124,5 +127,6 @@
MultiHeadCenterPillar,
)
from keras_cv.models.segmentation import DeepLabV3Plus
from keras_cv.models.segmentation import SegFormer
from keras_cv.models.stable_diffusion import StableDiffusion
from keras_cv.models.stable_diffusion import StableDiffusionV2
13 changes: 13 additions & 0 deletions keras_cv/models/backbones/mix_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
138 changes: 138 additions & 0 deletions keras_cv/models/backbones/mix_transformer/mix_transformer_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MiT backbone model.

References:

""" # noqa: E501

import numpy as np

from keras_cv import layers as cv_layers
from keras_cv.backend import keras
from keras_cv.models import utils
from keras_cv.models.backbones.backbone import Backbone
from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone_presets import ( # noqa: E501
backbone_presets,
)
from keras_cv.models.backbones.mix_transformer.mix_transformer_backbone_presets import ( # noqa: E501
backbone_presets_with_weights,
)
from keras_cv.utils.python_utils import classproperty


@keras.saving.register_keras_serializable(package="keras_cv.models")
class MiTBackbone(Backbone):
    """Mix-Transformer (MiT) backbone, the encoder used by SegFormer.

    Builds four stages; each stage is an overlapping patch embedding
    followed by a stack of hierarchical transformer encoders, with the
    output reshaped back to a `(B, H, W, C)` feature map. Intermediate
    feature maps are collected in `pyramid_level_inputs` for use by
    segmentation heads.

    Args:
        input_shape: optional shape tuple of the input image.
            NOTE(review): reviewers suggested defaulting this to
            `(None, None, 3)` so the channel dim is known at build time,
            vs. `(224, 224, 3)` since the shape must be known at
            instantiation — confirm the intended default.
        input_tensor: optional tensor to use as model input.
        classes: int, number of classes for the top. Required when
            `include_top=True`.
        include_top: bool, whether a classification top is attached.
        embed_dims: list of 4 ints, embedding dims per stage.
        depths: list of 4 ints, number of transformer blocks per stage.
        pooling: optional pooling mode; must be None when
            `include_top=True`.
    """

    def __init__(
        self,
        input_shape=None,
        input_tensor=None,
        classes=None,
        include_top=None,
        embed_dims=None,
        depths=None,
        pooling=None,
        **kwargs,
    ):
        if include_top and not classes:
            raise ValueError(
                "If `include_top` is True, you should specify `classes`. "
                f"Received: classes={classes}"
            )

        if include_top and pooling:
            raise ValueError(
                f"`pooling` must be `None` when `include_top=True`."
                f"Received pooling={pooling} and include_top={include_top}. "
            )

        # Per-block stochastic-depth rates increase linearly with depth.
        drop_path_rate = 0.1
        dpr = [x for x in np.linspace(0.0, drop_path_rate, sum(depths))]
        blockwise_num_heads = [1, 2, 5, 8]
        blockwise_sr_ratios = [8, 4, 2, 1]
        num_stages = 4

        cur = 0
        patch_embedding_layers = []
        transformer_blocks = []
        layer_norms = []

        for i in range(num_stages):
            # Stage 0 uses a larger patch/stride (7/4); later stages 3/2.
            patch_embed_layer = cv_layers.OverlappingPatchingAndEmbedding(
                out_channels=embed_dims[0] if i == 0 else embed_dims[i],
                patch_size=7 if i == 0 else 3,
                stride=4 if i == 0 else 2,
                name=f"patch_and_embed_{i}",
            )
            patch_embedding_layers.append(patch_embed_layer)

            transformer_block = [
                cv_layers.HierarchicalTransformerEncoder(
                    project_dim=embed_dims[i],
                    num_heads=blockwise_num_heads[i],
                    sr_ratio=blockwise_sr_ratios[i],
                    drop_prob=dpr[cur + k],
                    name=f"hierarchical_encoder_{i}_{k}",
                )
                for k in range(depths[i])
            ]
            transformer_blocks.append(transformer_block)
            cur += depths[i]
            layer_norms.append(keras.layers.LayerNormalization())

        inputs = utils.parse_model_inputs(input_shape, input_tensor)
        x = inputs

        pyramid_level_inputs = []
        for i in range(num_stages):
            x, H, W = patch_embedding_layers[i](x)
            for blk in transformer_blocks[i]:
                x = blk(x, H, W)
            x = layer_norms[i](x)
            C = x.shape[-1]
            # Un-flatten tokens back to a spatial feature map. Fixed:
            # original called the tensor method `x.reshape(...)` (not
            # available on symbolic tensors) with a possibly-None static
            # batch size; `-1` lets the batch dim stay dynamic.
            x = keras.ops.reshape(x, (-1, H, W, C))
            pyramid_level_inputs.append(x)

        super().__init__(
            inputs=inputs,
            outputs=x,
            **kwargs,
        )

        self.channels = embed_dims
        self.num_stages = num_stages
        self.output_channels = embed_dims
        self.embed_dims = embed_dims
        self.depths = depths
        self.classes = classes
        self.include_top = include_top
        self.pyramid_level_inputs = pyramid_level_inputs
        self.pooling = pooling

        # Fixed: these previously discarded the built layers by assigning
        # empty lists; keep references to the actual stage layers.
        self.patch_embedding_layers = patch_embedding_layers
        self.transformer_blocks = transformer_blocks

    def get_config(self):
        # Fixed: the config previously serialized keys that are not
        # `__init__` arguments (`channels`, `num_stages`,
        # `output_channels`) and omitted `embed_dims`/`depths`, so the
        # model could not be reconstructed from its config.
        config = super().get_config()
        config.update(
            {
                "embed_dims": self.embed_dims,
                "depths": self.depths,
                "classes": self.classes,
                "include_top": self.include_top,
                "pooling": self.pooling,
            }
        )
        return config
Loading