22from packaging .version import parse as package_version_parse
33
44from .yolov5_utils import scale_img
5- from copy import deepcopy
5+ from copy import deepcopy
66from .common import *
77
88class Detect (nn .Module ):
@@ -213,7 +213,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
213213 "SPP" : SPP , "SPPF" : SPPF , "DWConv" : DWConv , "Focus" : Focus , "BottleneckCSP" : BottleneckCSP ,
214214 "C3" : C3 , "C3TR" : C3TR , "C3SPP" : C3SPP , "C3Ghost" : C3Ghost , "Concat" : Concat ,
215215 "Detect" : Detect , "Contract" : Contract , "Expand" : Expand , "nn.BatchNorm2d" : nn .BatchNorm2d ,
216- "BatchNorm2d" : nn .BatchNorm2d ,
216+ "BatchNorm2d" : nn .BatchNorm2d , "nn.Upsample" : nn . Upsample , "Upsample" : nn . Upsample ,
217217 }
218218
219219 def resolve_module (module ):
@@ -262,6 +262,8 @@ def parse_arg(value):
262262 args .append ([ch [x ] for x in f ])
263263 if isinstance (args [1 ], int ): # number of anchors
264264 args [1 ] = [list (range (args [1 ] * 2 ))] * len (f )
265+ elif m is nn .Upsample :
266+ c2 = ch [f ]
265267 elif m is Contract :
266268 c2 = ch [f ] * args [0 ] ** 2
267269 elif m is Expand :
@@ -286,7 +288,7 @@ def load_yolov5(weights, map_location='cuda', fuse=True, inplace=True, out_indic
286288 ckpt = torch .load (weights , map_location = map_location ) # load
287289 else :
288290 ckpt = weights
289-
291+
290292 if fuse :
291293 model = ckpt ['model' ].float ().fuse ().eval () # FP32 model
292294 else :
@@ -311,10 +313,10 @@ def load_yolov5_ckpt(weights, map_location='cpu', fuse=True, inplace=True, out_i
311313 ckpt = torch .load (weights , map_location = map_location ) # load
312314 else :
313315 ckpt = weights
314-
316+
315317 model = Model (ckpt ['cfg' ])
316318 model .load_state_dict (ckpt ['weights' ], strict = True )
317-
319+
318320 if fuse :
319321 model = model .float ().fuse ().eval () # FP32 model
320322 else :
0 commit comments