Skip to content

Commit a40f3cc

Browse files
Refactor yolo.py for code consistency and clarity
1 parent 1a2fe22 commit a40f3cc

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

modules/textdetector/yolov5/yolo.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from packaging.version import parse as package_version_parse
33

44
from .yolov5_utils import scale_img
5-
from copy import deepcopy
5+
from copy import deepcopy
66
from .common import *
77

88
class Detect(nn.Module):
@@ -213,7 +213,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
213213
"SPP": SPP, "SPPF": SPPF, "DWConv": DWConv, "Focus": Focus, "BottleneckCSP": BottleneckCSP,
214214
"C3": C3, "C3TR": C3TR, "C3SPP": C3SPP, "C3Ghost": C3Ghost, "Concat": Concat,
215215
"Detect": Detect, "Contract": Contract, "Expand": Expand, "nn.BatchNorm2d": nn.BatchNorm2d,
216-
"BatchNorm2d": nn.BatchNorm2d,
216+
"BatchNorm2d": nn.BatchNorm2d, "nn.Upsample": nn.Upsample, "Upsample": nn.Upsample,
217217
}
218218

219219
def resolve_module(module):
@@ -262,6 +262,8 @@ def parse_arg(value):
262262
args.append([ch[x] for x in f])
263263
if isinstance(args[1], int): # number of anchors
264264
args[1] = [list(range(args[1] * 2))] * len(f)
265+
elif m is nn.Upsample:
266+
c2 = ch[f]
265267
elif m is Contract:
266268
c2 = ch[f] * args[0] ** 2
267269
elif m is Expand:
@@ -286,7 +288,7 @@ def load_yolov5(weights, map_location='cuda', fuse=True, inplace=True, out_indic
286288
ckpt = torch.load(weights, map_location=map_location) # load
287289
else:
288290
ckpt = weights
289-
291+
290292
if fuse:
291293
model = ckpt['model'].float().fuse().eval() # FP32 model
292294
else:
@@ -311,10 +313,10 @@ def load_yolov5_ckpt(weights, map_location='cpu', fuse=True, inplace=True, out_i
311313
ckpt = torch.load(weights, map_location=map_location) # load
312314
else:
313315
ckpt = weights
314-
316+
315317
model = Model(ckpt['cfg'])
316318
model.load_state_dict(ckpt['weights'], strict=True)
317-
319+
318320
if fuse:
319321
model = model.float().fuse().eval() # FP32 model
320322
else:

0 commit comments

Comments
 (0)