diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 0fe7b300fa..c216befdb9 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -40,6 +40,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter, ) +from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import ( + MiTImageConverter, +) from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import ( PaliGemmaImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 1450ddceb3..ffe38ad002 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -206,6 +206,9 @@ from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( MiTImageClassifier, ) +from keras_hub.src.models.mix_transformer.mix_transformer_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, +) from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( MobileNetImageClassifier, diff --git a/keras_hub/src/models/mix_transformer/__init__.py b/keras_hub/src/models/mix_transformer/__init__.py index e69de29bb2..f2292a35ad 100644 --- a/keras_hub/src/models/mix_transformer/__init__.py +++ b/keras_hub/src/models/mix_transformer/__init__.py @@ -0,0 +1,12 @@ +from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_hub.src.models.mix_transformer.mix_transformer_classifier import ( + MiTImageClassifier, +) +from keras_hub.src.models.mix_transformer.mix_transformer_presets import ( + backbone_presets, +) +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, MiTBackbone) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py index 
0daac9327f..beab6646ba 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py @@ -3,8 +3,12 @@ from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( MiTBackbone, ) +from keras_hub.src.models.mix_transformer.mix_transformer_classifier_preprocessor import ( + MiTImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.MiTImageClassifier") class MiTImageClassifier(ImageClassifier): backbone_cls = MiTBackbone + preprocessor_cls = MiTImageClassifierPreprocessor diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py b/keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py new file mode 100644 index 0000000000..61c994c5fb --- /dev/null +++ b/keras_hub/src/models/mix_transformer/mix_transformer_classifier_preprocessor.py @@ -0,0 +1,16 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.mix_transformer.mix_transformer_backbone import ( + MiTBackbone, +) +from keras_hub.src.models.mix_transformer.mix_transformer_image_converter import ( + MiTImageConverter, +) + + +@keras_hub_export("keras_hub.models.MiTImageClassifierPreprocessor") +class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = MiTBackbone + image_converter_cls = MiTImageConverter diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py b/keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py new file mode 100644 index 0000000000..e59ea26b66 --- /dev/null +++ b/keras_hub/src/models/mix_transformer/mix_transformer_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.mix_transformer 
import MiTBackbone + + +@keras_hub_export("keras_hub.layers.MiTImageConverter") +class MiTImageConverter(ImageConverter): + backbone_cls = MiTBackbone diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py b/keras_hub/src/models/mix_transformer/mix_transformer_layers.py index 42402da7ea..fc5180ca90 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_layers.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_layers.py @@ -28,19 +28,23 @@ def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs): self.patch_size = patch_size self.stride = stride + padding_size = self.patch_size // 2 + + self.padding = keras.layers.ZeroPadding2D( + padding=(padding_size, padding_size) + ) self.proj = keras.layers.Conv2D( filters=project_dim, kernel_size=patch_size, strides=stride, - padding="same", + padding="valid", ) - self.norm = keras.layers.LayerNormalization() + self.norm = keras.layers.LayerNormalization(epsilon=1e-5) def call(self, x): + x = self.padding(x) x = self.proj(x) - # B, H, W, C - shape = x.shape - x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3])) + x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3])) x = self.norm(x) return x diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_presets.py b/keras_hub/src/models/mix_transformer/mix_transformer_presets.py new file mode 100644 index 0000000000..840fa9e1d3 --- /dev/null +++ b/keras_hub/src/models/mix_transformer/mix_transformer_presets.py @@ -0,0 +1,151 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""MiT model preset configurations.""" + +backbone_presets_with_weights = { + "mit_b0_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b0_ade20k_512", + }, + "mit_b1_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b1_ade20k_512", + }, + "mit_b2_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b2_ade20k_512", + }, + "mit_b3_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." + ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b3_ade20k_512", + }, + "mit_b4_ade20k_512": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b4_ade20k_512", + }, + "mit_b5_ade20k_640": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." + ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b5_ade20k_640", + }, + "mit_b0_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks."
+ ), + "params": 3321962, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b0_cityscapes_1024", + }, + "mit_b1_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 8 transformer blocks." + ), + "params": 13156554, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b1_cityscapes_1024", + }, + "mit_b2_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 16 transformer blocks." + ), + "params": 24201418, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b2_cityscapes_1024", + }, + "mit_b3_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 28 transformer blocks." + ), + "params": 44077258, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b3_cityscapes_1024", + }, + "mit_b4_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 41 transformer blocks." + ), + "params": 60847818, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b4_cityscapes_1024", + }, + "mit_b5_cityscapes_1024": { + "metadata": { + "description": ( + "MiT (MixTransformer) model with 52 transformer blocks." 
+ ), + "params": 81448138, + "official_name": "MiT", + "path": "mit", + }, + "kaggle_handle": "kaggle://kerashub/mix-transformer/keras/mit_b5_cityscapes_1024", + }, +} + +backbone_presets = { + **backbone_presets_with_weights, +} diff --git a/tools/checkpoint_conversion/convert_mix_transformer.py b/tools/checkpoint_conversion/convert_mix_transformer.py new file mode 100644 index 0000000000..6419cc405e --- /dev/null +++ b/tools/checkpoint_conversion/convert_mix_transformer.py @@ -0,0 +1,196 @@ +# Usage example +# python tools/checkpoint_conversion/convert_mix_transformer.py --preset "B0_ade_512" + +from absl import app +from absl import flags +from transformers import SegformerForSemanticSegmentation + +import keras_hub + +FLAGS = flags.FLAGS + + +DOWNLOAD_URLS = { + "B0_ade_512": "nvidia/segformer-b0-finetuned-ade-512-512", + "B1_ade_512": "nvidia/segformer-b1-finetuned-ade-512-512", + "B2_ade_512": "nvidia/segformer-b2-finetuned-ade-512-512", + "B3_ade_512": "nvidia/segformer-b3-finetuned-ade-512-512", + "B4_ade_512": "nvidia/segformer-b4-finetuned-ade-512-512", + "B5_ade_640": "nvidia/segformer-b5-finetuned-ade-640-640", + "B0_cityscapes_1024": "nvidia/segformer-b0-finetuned-cityscapes-1024-1024", + "B1_cityscapes_1024": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024", + "B2_cityscapes_1024": "nvidia/segformer-b2-finetuned-cityscapes-1024-1024", + "B3_cityscapes_1024": "nvidia/segformer-b3-finetuned-cityscapes-1024-1024", + "B4_cityscapes_1024": "nvidia/segformer-b4-finetuned-cityscapes-1024-1024", + "B5_cityscapes_1024": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024", +} + + +MODEL_CONFIGS = { + "B0": {"hidden_dims": [32, 64, 160, 256], "depths": [2, 2, 2, 2]}, + "B1": {"hidden_dims": [64, 128, 320, 512], "depths": [2, 2, 2, 2]}, + "B2": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 6, 3]}, + "B3": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 4, 18, 3]}, + "B4": {"hidden_dims": [64, 128, 320, 512], "depths": [3, 8, 27, 3]}, + "B5": 
{"hidden_dims": [64, 128, 320, 512], "depths": [3, 6, 40, 3]}, +} + +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' +) + + +def get_indices_from_depths(depths): + proj_indices = [] + norm_indices = [] + hierarchical_encoder_indices = [] + + current_layer_idx = 1 + + for layer_idx, depth in enumerate(depths): + # Add projection index (before the hierarchical encoders) + proj_indices.append(current_layer_idx) + + # Hierarchical encoder block indices + for block_idx in range(depth): + hierarchical_encoder_indices.append( + (current_layer_idx + 1, layer_idx, block_idx) + ) + current_layer_idx += 1 + + # Add normalization index (after the hierarchical encoders) + norm_indices.append(current_layer_idx + 1) + + # Skip to the next layer after output_level + current_layer_idx += 3 + + return proj_indices, norm_indices, hierarchical_encoder_indices + + +def set_conv_weights(conv_layer, state_dict): + conv_weights = state_dict["weight"].numpy().transpose(2, 3, 1, 0) + conv_bias = state_dict["bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_dwconv_weights(conv_layer, state_dict): + conv_weights = state_dict["dwconv.weight"].numpy().transpose(2, 3, 0, 1) + conv_bias = state_dict["dwconv.bias"].numpy() + conv_layer.set_weights([conv_weights, conv_bias]) + + +def set_layer_norm_weights(layer_norm, state_dict): + gamma = state_dict["weight"].numpy() + beta = state_dict["bias"].numpy() + layer_norm.set_weights([gamma, beta]) + + +def set_dense_weights(dense_layer, state_dict): + weight = state_dict["weight"].numpy().T + bias = state_dict["bias"].numpy() + dense_layer.set_weights([weight, bias]) + + +def set_hierarchical_encoder_weights(keras_layer, pytorch_layer, key): + + set_layer_norm_weights( + keras_layer.norm1, pytorch_layer.layer_norm_1.state_dict() + ) + + set_dense_weights( + keras_layer.attn.q, pytorch_layer.attention.self.query.state_dict() + ) + set_dense_weights( + keras_layer.attn.k, 
pytorch_layer.attention.self.key.state_dict() + ) + set_dense_weights( + keras_layer.attn.v, pytorch_layer.attention.self.value.state_dict() + ) + set_dense_weights( + keras_layer.attn.proj, pytorch_layer.attention.output.dense.state_dict() + ) + + if keras_layer.attn.sr_ratio > 1: + set_conv_weights( + keras_layer.attn.sr, pytorch_layer.attention.self.sr.state_dict() + ) + set_layer_norm_weights( + keras_layer.attn.norm, + pytorch_layer.attention.self.layer_norm.state_dict(), + ) + + set_layer_norm_weights( + keras_layer.norm2, pytorch_layer.layer_norm_2.state_dict() + ) + + set_dense_weights( + keras_layer.mlp.fc1, pytorch_layer.mlp.dense1.state_dict() + ) + set_dwconv_weights( + keras_layer.mlp.dwconv, pytorch_layer.mlp.dwconv.state_dict() + ) + set_dense_weights( + keras_layer.mlp.fc2, pytorch_layer.mlp.dense2.state_dict() + ) + + +def main(_): + print("\n-> Loading HuggingFace model") + model = SegformerForSemanticSegmentation.from_pretrained( + DOWNLOAD_URLS[FLAGS.preset] + ) + original_mit = model.segformer.encoder + + model_type = FLAGS.preset.split("_")[0] + print("\n-> Instantiating KerasHub Model") + keras_mit = keras_hub.models.MiTBackbone( + depths=MODEL_CONFIGS[model_type]["depths"], + image_shape=(224, 224, 3), + hidden_dims=MODEL_CONFIGS[model_type]["hidden_dims"], + num_layers=4, + blockwise_num_heads=[1, 2, 5, 8], + blockwise_sr_ratios=[8, 4, 2, 1], + max_drop_path_rate=0.1, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + ) + + # Indices for the different patch embeddings and layer norms + proj_indices, layer_norm_indices, hierarchical_encoder_indices = ( + get_indices_from_depths(MODEL_CONFIGS[model_type]["depths"]) + ) + + print("\n-> Converting weights...") + # Loop through the indices to set convolutional and normalization weights + for i, idx in enumerate(proj_indices): + set_conv_weights( + keras_mit.layers[idx].proj, + original_mit.patch_embeddings[i].proj.state_dict(), + ) + set_layer_norm_weights( + 
keras_mit.layers[idx].norm, + original_mit.patch_embeddings[i].layer_norm.state_dict(), + ) + + # Set layer normalization weights + for i, idx in enumerate(layer_norm_indices): + set_layer_norm_weights( + keras_mit.layers[idx], original_mit.layer_norm[i].state_dict() + ) + + # Set hierarchical encoder weights + for layer_idx, block_idx, key in hierarchical_encoder_indices: + set_hierarchical_encoder_weights( + keras_mit.layers[layer_idx], + original_mit.block[block_idx][int(key)], + key=key, + ) + + directory = f"MiT_{FLAGS.preset}" + print(f"\n-> Saving converted KerasHub model in {directory}") + keras_mit.save_to_preset(directory) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main)