This repository was archived by the owner on Mar 10, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathdarknet_backbone.py
More file actions
293 lines (250 loc) · 9.26 KB
/
darknet_backbone.py
File metadata and controls
293 lines (250 loc) · 9.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DarkNet backbone model.
Reference:
- [YoloV3 Paper](https://arxiv.org/abs/1804.02767)
- [YoloV3 implementation](https://github.com/ultralytics/yolov3)
"""
import copy
from tensorflow import keras
from tensorflow.keras import layers
from keras_cv.models.backbones.backbone import Backbone
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
DarknetConvBlock,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
ResidualBlocks,
)
from keras_cv.models.backbones.csp_darknet.csp_darknet_utils import (
SpatialPyramidPoolingBottleneck,
)
from keras_cv.models.backbones.darknet.darknet_backbone_presets import (
backbone_presets,
)
from keras_cv.models.backbones.darknet.darknet_backbone_presets import (
backbone_presets_with_weights,
)
from keras_cv.models.legacy import utils
from keras_cv.utils.python_utils import classproperty
@keras.utils.register_keras_serializable(package="keras_cv.models")
class DarkNetBackbone(Backbone):
    """Represents the DarkNet architecture.

    The DarkNet architecture is commonly used for detection tasks. It is
    possible to extract the intermediate dark2 to dark5 layers from the model
    for creating a feature pyramid Network.

    Reference:
        - [YoloV3 Paper](https://arxiv.org/abs/1804.02767)
        - [YoloV3 implementation](https://github.com/ultralytics/yolov3)

    For transfer learning use cases, make sure to read the
    [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).

    Args:
        stackwise_blocks: list of four integers, the number of residual
            building blocks in each of the stages dark2, dark3, dark4 and
            dark5 (in that order).
        include_rescaling: bool, whether to rescale the inputs. If set to True,
            inputs will be passed through a `Rescaling(1/255.0)` layer.
        input_shape: optional shape tuple, defaults to (None, None, 3).
        input_tensor: optional Keras tensor (i.e., output of `layers.Input()`)
            to use as image input for the model.

    Raises:
        ValueError: if `stackwise_blocks` does not have exactly one entry per
            residual stage (dark2 to dark5). Previously a list of the wrong
            length was silently truncated by `zip`, dropping stages.

    Examples:
    ```python
    input_data = tf.ones(shape=(8, 224, 224, 3))

    # Pretrained backbone
    model = keras_cv.models.DarkNetBackbone.from_preset("darknet53_imagenet")
    output = model(input_data)

    # Randomly initialized backbone with a custom config
    model = DarkNetBackbone(
        stackwise_blocks=[2, 8, 8, 4],
        include_rescaling=False,
    )
    output = model(input_data)
    ```
    """  # noqa: E501

    def __init__(
        self,
        stackwise_blocks,
        include_rescaling,
        input_shape=(None, None, 3),
        input_tensor=None,
        **kwargs,
    ):
        # Output filter count of each residual stage, dark2 through dark5.
        stage_filters = (128, 256, 512, 1024)
        # Validate up front: `zip` below would otherwise silently drop stages
        # (or ignore extra entries) when the lengths disagree.
        if len(stackwise_blocks) != len(stage_filters):
            raise ValueError(
                "`stackwise_blocks` must contain exactly "
                f"{len(stage_filters)} entries (one per stage, dark2 to "
                f"dark5). Received: stackwise_blocks={stackwise_blocks}"
            )

        inputs = utils.parse_model_inputs(input_shape, input_tensor)
        x = inputs

        if include_rescaling:
            x = layers.Rescaling(1 / 255.0)(x)

        # Maps an integer key to the name of the layer producing that output.
        # NOTE(review): keys are assigned sequentially in build order
        # (stem_conv=2 ... dark5_conv4=11); confirm whether consumers expect
        # stride-based pyramid levels instead.
        pyramid_level_inputs = {}

        # stem
        x = DarknetConvBlock(
            filters=32,
            kernel_size=3,
            strides=1,
            activation="leaky_relu",
            name="stem_conv",
        )(x)
        pyramid_level_inputs[2] = x.node.layer.name
        x = ResidualBlocks(
            filters=64, num_blocks=1, name="stem_residual_block"
        )(x)
        pyramid_level_inputs[3] = x.node.layer.name

        # Residual stages; naming starts at dark2, and each stage's key is
        # two above its stage number (dark2 -> 4, ..., dark5 -> 7).
        for layer_num, (num_filters, num_blocks) in enumerate(
            zip(stage_filters, stackwise_blocks), start=2
        ):
            x = ResidualBlocks(
                filters=num_filters,
                num_blocks=num_blocks,
                name=f"dark{layer_num}_residual_block",
            )(x)
            pyramid_level_inputs[layer_num + 2] = x.node.layer.name

        # remaining dark5 layers
        x = DarknetConvBlock(
            filters=512,
            kernel_size=1,
            strides=1,
            activation="leaky_relu",
            name="dark5_conv1",
        )(x)
        pyramid_level_inputs[8] = x.node.layer.name
        x = DarknetConvBlock(
            filters=1024,
            kernel_size=3,
            strides=1,
            activation="leaky_relu",
            name="dark5_conv2",
        )(x)
        pyramid_level_inputs[9] = x.node.layer.name
        # The SPP bottleneck output is intentionally not recorded as a key.
        x = SpatialPyramidPoolingBottleneck(
            512, activation="leaky_relu", name="dark5_spp"
        )(x)
        x = DarknetConvBlock(
            filters=1024,
            kernel_size=3,
            strides=1,
            activation="leaky_relu",
            name="dark5_conv3",
        )(x)
        pyramid_level_inputs[10] = x.node.layer.name
        x = DarknetConvBlock(
            filters=512,
            kernel_size=1,
            strides=1,
            activation="leaky_relu",
            name="dark5_conv4",
        )(x)
        pyramid_level_inputs[11] = x.node.layer.name

        super().__init__(inputs=inputs, outputs=x, **kwargs)

        # Stored so that `get_config` can reproduce the constructor arguments.
        self.pyramid_level_inputs = pyramid_level_inputs
        self.stackwise_blocks = stackwise_blocks
        self.include_rescaling = include_rescaling
        self.input_tensor = input_tensor

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        config = super().get_config()
        config.update(
            {
                "stackwise_blocks": self.stackwise_blocks,
                "include_rescaling": self.include_rescaling,
                # Drop the batch dimension to match the `input_shape` arg.
                "input_shape": self.input_shape[1:],
                "input_tensor": self.input_tensor,
            }
        )
        return config

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        # Deep copy so callers cannot mutate the module-level preset table.
        return copy.deepcopy(backbone_presets)

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include weights."""  # noqa: E501
        return copy.deepcopy(backbone_presets_with_weights)
# Template docstring shared by the fixed-depth DarkNet alias classes in this
# module. The `{num_layers}` placeholder is filled in with `str.format` and the
# result is assigned as each alias class's `__doc__` at the bottom of the file.
ALIAS_DOCSTRING = """DarkNet model with {num_layers} layers.
Although the DarkNet architecture is commonly used for detection tasks, it
is possible to extract the intermediate dark2 to dark5 layers from the model
for creating a feature pyramid Network.
Reference:
- [YoloV3 Paper](https://arxiv.org/abs/1804.02767)
- [YoloV3 implementation](https://github.com/ultralytics/yolov3)
For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Args:
include_rescaling: bool, whether to rescale the inputs. If set to
True, inputs will be passed through a `Rescaling(1/255.0)` layer.
input_shape: optional shape tuple, defaults to (None, None, 3).
input_tensor: optional Keras tensor (i.e., output of `layers.Input()`)
to use as image input for the model.
Examples:
```python
input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = DarkNet{num_layers}Backbone()
output = model(input_data)
```
""" # noqa: E501
class DarkNet21Backbone(DarkNetBackbone):
    # Fixed-depth alias: constructing this class delegates to the
    # "darknet21" preset of DarkNetBackbone.

    def __new__(
        cls,
        include_rescaling=True,
        input_shape=(None, None, 3),
        input_tensor=None,
        **kwargs,
    ):
        """Build a model from the "darknet21" preset.

        The named arguments are merged after any caller-supplied keyword
        arguments and everything is forwarded to
        `DarkNetBackbone.from_preset`, whose result is returned as-is.
        """
        preset_kwargs = {
            **kwargs,
            "include_rescaling": include_rescaling,
            "input_shape": input_shape,
            "input_tensor": input_tensor,
        }
        return DarkNetBackbone.from_preset("darknet21", **preset_kwargs)

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        # This alias exposes no presets of its own.
        return dict()

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include weights."""  # noqa: E501
        return dict()
class DarkNet53Backbone(DarkNetBackbone):
    # Fixed-depth alias: constructing this class delegates to the
    # "darknet53" preset of DarkNetBackbone.

    def __new__(
        cls,
        include_rescaling=True,
        input_shape=(None, None, 3),
        input_tensor=None,
        **kwargs,
    ):
        """Build a model from the "darknet53" preset.

        The named arguments are merged after any caller-supplied keyword
        arguments and everything is forwarded to
        `DarkNetBackbone.from_preset`, whose result is returned as-is.
        """
        preset_kwargs = {
            **kwargs,
            "include_rescaling": include_rescaling,
            "input_shape": input_shape,
            "input_tensor": input_tensor,
        }
        return DarkNetBackbone.from_preset("darknet53", **preset_kwargs)

    @classproperty
    def presets(cls):
        """Dictionary of preset names and configurations."""
        preset_name = "darknet53_imagenet"
        # Deep copy so callers cannot mutate the module-level preset table.
        return {preset_name: copy.deepcopy(backbone_presets[preset_name])}

    @classproperty
    def presets_with_weights(cls):
        """Dictionary of preset names and configurations that include weights."""  # noqa: E501
        # Every preset exposed by this alias ships with weights.
        return cls.presets
# Give each alias class the shared template docstring, filled in with its depth.
DarkNet21Backbone.__doc__ = ALIAS_DOCSTRING.format(num_layers=21)
DarkNet53Backbone.__doc__ = ALIAS_DOCSTRING.format(num_layers=53)