Skip to content

Commit 7071e1e

Browse files
Resolved failing test cases.
1 parent 616acd6 commit 7071e1e

10 files changed

+125
-79
lines changed

keras_hub/api/layers/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
4949
DeepLabV3ImageConverter,
5050
)
51+
from keras_hub.src.models.deit.deit_image_converter import DeiTImageConverter
5152
from keras_hub.src.models.densenet.densenet_image_converter import (
5253
DenseNetImageConverter,
5354
)
@@ -71,9 +72,6 @@
7172
RetinaNetImageConverter,
7273
)
7374
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
74-
from keras_hub.src.models.deit.deit_image_converter import (
75-
DeiTImageConverter,
76-
)
7775
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
7876
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
7977
from keras_hub.src.models.segformer.segformer_image_converter import (

keras_hub/api/models/__init__.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,18 @@
101101
from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
102102
DeepLabV3ImageSegmenter,
103103
)
104+
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
105+
from keras_hub.src.models.deit.deit_image_classifier import DeiTImageClassifier
106+
from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
107+
DeiTImageClassifierPreprocessor,
108+
)
104109
from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
105110
from keras_hub.src.models.densenet.densenet_image_classifier import (
106111
DenseNetImageClassifier,
107112
)
108113
from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
109114
DenseNetImageClassifierPreprocessor,
110115
)
111-
from keras_hub.src.models.deit.deit_backbone import(
112-
DeiTBackbone,
113-
)
114-
from keras_hub.src.models.deit.deit_image_classifier import DeiTImageClassifier
115-
from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
116-
DeiTImageClassifierPreprocessor,
117-
)
118116
from keras_hub.src.models.distil_bert.distil_bert_backbone import (
119117
DistilBertBackbone,
120118
)

keras_hub/src/models/deit/deit_backbone.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,46 @@
22

33
from keras_hub.src.api_export import keras_hub_export
44
from keras_hub.src.models.backbone import Backbone
5-
from keras_hub.src.utils.keras_utils import standardize_data_format
6-
from keras_hub.src.models.deit.deit_layers import DeiTEncoder
75
from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
6+
from keras_hub.src.models.deit.deit_layers import DeiTEncoder
7+
from keras_hub.src.utils.keras_utils import standardize_data_format
8+
89

910
@keras_hub_export("keras_hub.models.DeiTBackbone")
1011
class DeiTBackbone(Backbone):
1112
"""DeiT backbone.
12-
13-
This backbone implements the Data-efficient Image Transformer (DeiT) architecture
14-
as described in [Training data-efficient image transformers & distillation through
15-
attention](https://arxiv.org/abs/2012.12877).
16-
13+
14+
This backbone implements the Data-efficient Image Transformer (DeiT)
15+
architecture as described in [Training data-efficient image
16+
transformers & distillation through attention]
17+
(https://arxiv.org/abs/2012.12877).
18+
1719
Args:
1820
image_shape: A tuple or list of 3 integers representing the shape of the
1921
input image `(height, width, channels)`, `height` and `width` must
2022
be equal.
21-
patch_size: int. The size of each image patch, the input image
22-
will be divided into patches of shape `(patch_size_h, patch_size_w)`.
23+
patch_size: int. The size of each image patch, the input image will
24+
be divided into patches of shape `(patch_size_h, patch_size_w)`.
2325
num_layers: int. The number of transformer encoder layers.
24-
num_heads: int. The number of attention heads in each Transformer encoder layer.
26+
num_heads: int. The number of attention heads in each Transformer
27+
encoder layer.
2528
hidden_dim: int. The dimensionality of the hidden representations.
26-
intermediate_dim: int. The dimensionality of the intermediate MLP layer in
27-
each Transformer encoder layer.
28-
dropout_rate: float. The dropout rate for the Transformer encoder layers.
29+
intermediate_dim: int. The dimensionality of the intermediate MLP layer
30+
in each Transformer encoder layer.
31+
dropout_rate: float. The dropout rate for the Transformer encoder
32+
layers.
2933
attention_dropout: float. The dropout rate for the attention mechanism
3034
in each Transformer encoder layer.
31-
layer_norm_epsilon: float. Value used for numerical stability in layer normalization.
32-
use_mha_bias: bool. Whether to use bias in the multi-head attention layers.
35+
layer_norm_epsilon: float. Value used for numerical stability in layer
36+
normalization.
37+
use_mha_bias: bool. Whether to use bias in the multi-head attention
38+
layers.
3339
data_format: str. `"channels_last"` or `"channels_first"`, specifying
34-
the data format for the input image. If `None`, defaults to `"channels_last"`.
40+
the data format for the input image. If `None`, defaults to
41+
`"channels_last"`.
3542
dtype: The dtype of the layer weights. Defaults to None.
36-
**kwargs: Additional keyword arguments to be passed to the parent `Backbone` class.
43+
**kwargs: Additional keyword arguments to be passed to the parent
44+
`Backbone` class.
3745
"""
3846

3947
def __init__(
@@ -108,8 +116,8 @@ def __init__(
108116
attention_dropout=attention_dropout,
109117
layer_norm_epsilon=layer_norm_epsilon,
110118
dtype=dtype,
111-
name="deit_encoder"
112-
)(x)
119+
name="deit_encoder",
120+
)(x)
113121

114122
super().__init__(
115123
inputs=inputs,
@@ -130,7 +138,7 @@ def __init__(
130138
self.layer_norm_epsilon = layer_norm_epsilon
131139
self.use_mha_bias = use_mha_bias
132140
self.data_format = data_format
133-
141+
134142
def get_config(self):
135143
config = super().get_config()
136144
config.update(
@@ -147,4 +155,4 @@ def get_config(self):
147155
"use_mha_bias": self.use_mha_bias,
148156
}
149157
)
150-
return config
158+
return config

keras_hub/src/models/deit/deit_backbone_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_backbone_basics(self):
2525
init_kwargs={**self.init_kwargs},
2626
input_data=self.input_data,
2727
# 49+2 positions(49 patches, cls and distillation token)
28-
expected_output_shape=(2, 51, 48)
28+
expected_output_shape=(2, 51, 48),
2929
run_quantization_check=False,
3030
)
3131

keras_hub/src/models/deit/deit_image_classifier.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from keras import ops
33

44
from keras_hub.src.api_export import keras_hub_export
5-
from keras_hub.src.models.image_classifier import ImageClassifier
6-
from keras_hub.src.models.task import Task
75
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
86
from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
97
DeiTImageClassifierPreprocessor,
108
)
9+
from keras_hub.src.models.image_classifier import ImageClassifier
10+
from keras_hub.src.models.task import Task
1111

1212

1313
@keras_hub_export("keras_hub.models.DeiTImageClassifier")

keras_hub/src/models/deit/deit_image_classifier_preprocessor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
3+
from keras_hub.src.models.deit.deit_image_converter import DeiTImageConverter
24
from keras_hub.src.models.image_classifier_preprocessor import (
35
ImageClassifierPreprocessor,
46
)
5-
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
6-
from keras_hub.src.models.deit.deit_image_converter import DeiTImageConverter
77

88

99
@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")

keras_hub/src/models/deit/deit_image_converter.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from keras_hub.src.api_export import keras_hub_export
2-
from keras_hub.src.utils.tensor_utils import preprocessing_function
32
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
43
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
4+
from keras_hub.src.utils.tensor_utils import preprocessing_function
55

66

77
@keras_hub_export("keras_hub.layers.DeiTImageConverter")
@@ -20,7 +20,9 @@ class DeiTImageConverter(ImageConverter):
2020
```python
2121
import keras
2222
import numpy as np
23-
from keras_hub.src.models.deit.deit_image_converter import DeiTImageConverter
23+
from keras_hub.src.models.deit.deit_image_converter import (
24+
DeiTImageConverter
25+
)
2426
# Example image (replace with your actual image data)
2527
image = np.random.rand(1, 384, 384, 3) # Example: (B, H, W, C)
2628
# Create a DeiTImageConverter instance
@@ -60,4 +62,4 @@ def get_config(self):
6062
"norm_std": self.norm_std,
6163
}
6264
)
63-
return config
65+
return config

0 commit comments

Comments
 (0)