|
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 15 | +import math |
| 16 | + |
| 17 | +from keras_cv.api_export import keras_cv_export |
| 18 | +from keras_cv.backend import keras |
| 19 | +from keras_cv.backend import ops |
| 20 | +from keras_cv.layers.regularization.drop_path import DropPath |
| 21 | +from keras_cv.layers.segformer_multihead_attention import ( |
| 22 | + SegFormerMultiheadAttention, |
| 23 | +) |
| 24 | + |
| 25 | + |
@keras_cv_export("keras_cv.layers.HierarchicalTransformerEncoder")
class HierarchicalTransformerEncoder(keras.layers.Layer):
    """
    Hierarchical transformer encoder block implementation as a Keras Layer.

    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
    alternative for computational efficiency, and is meant to be used
    within the SegFormer architecture.

    References:
        - [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) (CVPR 2021) # noqa: E501
        - [Official PyTorch implementation](https://github.com/NVlabs/SegFormer/blob/master/mmseg/models/backbones/mix_transformer.py) # noqa: E501
        - [Ported from the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/blob/main/deepvision/layers/hierarchical_transformer_encoder.py) # noqa: E501

    Args:
        project_dim: integer, the dimensionality of the projection of the
            encoder, and output of the `SegFormerMultiheadAttention` layer.
            Due to the residual addition the input dimensionality has to be
            equal to the output dimensionality.
        num_heads: integer, the number of heads for the
            `SegFormerMultiheadAttention` layer.
        sr_ratio: integer, the ratio to use within
            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
            layer is used to reduce the length of the sequence.
            Defaults to `1`.
        drop_prob: float, the probability of dropping a random
            sample using the `DropPath` layer. Defaults to `0.0`.
        layer_norm_epsilon: float, the epsilon for
            `LayerNormalization` layers. Defaults to `1e-06`.

    Basic usage:

    ```
    project_dim = 1024
    num_heads = 4
    patch_size = 16

    encoded_patches = keras_cv.layers.OverlappingPatchingAndEmbedding(
        project_dim=project_dim, patch_size=patch_size)(img_batch)

    trans_encoded = keras_cv.layers.HierarchicalTransformerEncoder(project_dim=project_dim,
                                                                   num_heads=num_heads,
                                                                   sr_ratio=1)(encoded_patches)

    print(trans_encoded.shape) # (1, 3136, 1024)
    ```
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Store every constructor argument so `get_config()` can round-trip
        # the layer through `from_config()`.
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.layer_norm_epsilon = layer_norm_epsilon
        self.drop_prob = drop_prob
        # Misspelled alias kept for backward compatibility with any code
        # that read the original `drop_prop` attribute.
        self.drop_prop = drop_prob

        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = SegFormerMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        # Stochastic depth applied to both residual branches.
        self.drop_path = DropPath(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        # Feed-forward sub-block with a 4x expansion factor.
        self.mlp = self.MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def build(self, input_shape):
        super().build(input_shape)
        # NOTE(review): `self.H`/`self.W` are computed here but never read
        # in `call()`. Additionally, `input_shape[1]` is the flattened
        # sequence length while `input_shape[2]` is the channel dim, so
        # `sqrt(input_shape[2])` is unlikely to be a spatial width — left
        # unchanged because nothing visible consumes these values.
        self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
        self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

    def call(self, x):
        """Pre-norm transformer block: self-attention then MixFFN, each
        wrapped in a residual connection with `DropPath`."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        # Serialize exactly the `__init__` arguments so that
        # `from_config(config)` can rebuild the layer. The previous
        # implementation emitted a serialized "mlp" sub-layer and a
        # misspelled "drop_prop" key — neither is accepted by
        # `__init__`, which made deserialization raise a TypeError —
        # and it omitted `sr_ratio` and `layer_norm_epsilon`.
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "sr_ratio": self.sr_ratio,
                "drop_prob": self.drop_prob,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    class MixFFN(keras.layers.Layer):
        """SegFormer Mix-FFN: Dense expansion -> 3x3 depthwise conv
        (token mixing in the spatial domain) -> GELU -> Dense projection
        back to `channels`."""

        def __init__(self, channels, mid_channels):
            super().__init__()
            self.fc1 = keras.layers.Dense(mid_channels)
            self.dwconv = keras.layers.DepthwiseConv2D(
                kernel_size=3,
                strides=1,
                padding="same",
            )
            self.fc2 = keras.layers.Dense(channels)

        def call(self, x):
            x = self.fc1(x)
            shape = ops.shape(x)
            # Fold the flattened (B, N, C) sequence back into a square
            # (B, H, W, C) feature map for the depthwise conv; assumes N
            # is a perfect square (H == W == sqrt(N)) — TODO confirm for
            # non-square inputs.
            H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
            B, C = shape[0], shape[2]
            x = ops.reshape(x, (B, H, W, C))
            x = self.dwconv(x)
            # Flatten back to the (B, N, C) sequence layout.
            x = ops.reshape(x, (B, -1, C))
            x = ops.nn.gelu(x)
            x = self.fc2(x)
            return x