 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-from keras_hub.src.models.deit.deit_layers import DeiTEncoder
 from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
+from keras_hub.src.models.deit.deit_layers import DeiTEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
 
 @keras_hub_export("keras_hub.models.DeiTBackbone")
 class DeiTBackbone(Backbone):
     """DeiT backbone.
-
-    This backbone implements the Data-efficient Image Transformer (DeiT) architecture
-    as described in [Training data-efficient image transformers & distillation through
-    attention](https://arxiv.org/abs/2012.12877).
-
+
+    This backbone implements the Data-efficient Image Transformer (DeiT)
+    architecture as described in [Training data-efficient image
+    transformers & distillation through attention]
+    (https://arxiv.org/abs/2012.12877).
+
     Args:
         image_shape: A tuple or list of 3 integers representing the shape of the
             input image `(height, width, channels)`, `height` and `width` must
             be equal.
-        patch_size: int. The size of each image patch, the input image
-            will be divided into patches of shape `(patch_size_h, patch_size_w)`.
+        patch_size: int. The size of each image patch, the input image will
+            be divided into patches of shape `(patch_size_h, patch_size_w)`.
         num_layers: int. The number of transformer encoder layers.
-        num_heads: int. The number of attention heads in each Transformer encoder layer.
+        num_heads: int. The number of attention heads in each Transformer
+            encoder layer.
         hidden_dim: int. The dimensionality of the hidden representations.
-        intermediate_dim: int. The dimensionality of the intermediate MLP layer in
-            each Transformer encoder layer.
-        dropout_rate: float. The dropout rate for the Transformer encoder layers.
+        intermediate_dim: int. The dimensionality of the intermediate MLP layer
+            in each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
         attention_dropout: float. The dropout rate for the attention mechanism
             in each Transformer encoder layer.
-        layer_norm_epsilon: float. Value used for numerical stability in layer normalization.
-        use_mha_bias: bool. Whether to use bias in the multi-head attention layers.
+        layer_norm_epsilon: float. Value used for numerical stability in layer
+            normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layers.
         data_format: str. `"channels_last"` or `"channels_first"`, specifying
-            the data format for the input image. If `None`, defaults to `"channels_last"`.
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
         dtype: The dtype of the layer weights. Defaults to None.
-        **kwargs: Additional keyword arguments to be passed to the parent `Backbone` class.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
     """
 
     def __init__(
@@ -108,8 +116,8 @@ def __init__(
             attention_dropout=attention_dropout,
             layer_norm_epsilon=layer_norm_epsilon,
             dtype=dtype,
-            name="deit_encoder"
-        )(x)
+            name="deit_encoder",
+        )(x)
 
         super().__init__(
             inputs=inputs,
@@ -130,7 +138,7 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_mha_bias = use_mha_bias
         self.data_format = data_format
-
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -147,4 +155,4 @@ def get_config(self):
                 "use_mha_bias": self.use_mha_bias,
             }
         )
-        return config
+        return config
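Since the reflowed docstring above is the only description of the `DeiTBackbone` constructor in this diff, here is a minimal usage sketch showing how the documented arguments fit together. The export path `keras_hub.models.DeiTBackbone` comes from the decorator in the diff; the hyperparameter values are illustrative assumptions (roughly a base-size DeiT), not defaults defined by this change.

```python
import numpy as np
import keras_hub

# Illustrative, base-size-style configuration; the specific values below are
# assumptions for demonstration, not defaults taken from this diff.
backbone = keras_hub.models.DeiTBackbone(
    image_shape=(224, 224, 3),  # height and width must be equal
    patch_size=16,              # 224 / 16 = 14 patches per side
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    intermediate_dim=3072,
    dropout_rate=0.0,
    attention_dropout=0.0,
)

# Run a dummy channels_last image through the backbone and inspect the output
# shape (typically a sequence of patch tokens plus DeiT's class/distillation
# tokens, each with `hidden_dim` features).
features = backbone(np.ones((1, 224, 224, 3), dtype="float32"))
print(features.shape)
```

Per the docstring, `height` and `width` must be equal; choosing 224 with `patch_size=16` gives a clean 14x14 grid of patches.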