@@ -419,11 +419,13 @@ def live_plot(epochs, mAP, loss, title=""):
         return self.model

     # @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")])
-    def deploy(self, model_type: str, model_path: str) -> None:
-        """Uploads provided weights file to Roboflow
+    def deploy(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
+        """Uploads provided weights file to Roboflow.

         Args:
-            model_path (str): File path to model weights to be uploaded
+            model_type (str): The type of the model to be deployed.
+            model_path (str): File path to the model weights to be uploaded.
+            filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
         """

         supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
@@ -432,7 +434,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
             raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))

         if "yolonas" in model_type:
-            self.deploy_yolonas(model_type, model_path)
+            self.deploy_yolonas(model_type, model_path, filename)
             return

         if "yolov8" in model_type:
@@ -457,7 +459,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
                 " Please install it with `pip install torch`"
             )

-        model = torch.load(os.path.join(model_path, "weights/best.pt"))
+        model = torch.load(os.path.join(model_path, filename))

         if isinstance(model["model"].names, list):
             class_names = model["model"].names
@@ -542,7 +544,7 @@ def deploy(self, model_type: str, model_path: str) -> None:

         self.upload_zip(model_type, model_path)

-    def deploy_yolonas(self, model_type: str, model_path: str) -> None:
+    def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
         try:
             import torch
         except ImportError:
@@ -551,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
                 " Please install it with `pip install torch`"
            )

-        model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
+        model = torch.load(os.path.join(model_path, filename), map_location="cpu")
         class_names = model["processing_params"]["class_names"]

         opt_path = os.path.join(model_path, "opt.yaml")
@@ -584,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
         with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
             json.dump(model_artifacts, fp)

-        shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
+        shutil.copy(os.path.join(model_path, filename), os.path.join(model_path, "state_dict.pt"))

         list_files = [
             "results.json",
@@ -602,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
                         compress_type=zipfile.ZIP_DEFLATED,
                     )
                 else:
-                    if file in ["model_artifacts.json", "best.pt"]:
+                    if file in ["model_artifacts.json", filename]:
                         raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))

         self.upload_zip(model_type, model_path)
0 commit comments