Skip to content

Commit 14ff57e

Browse files
authored
Merge pull request #242 from roboflow/upload_weights_file_name
specify model weights file name for uploading custom weights
2 parents f0609a9 + 775a5d3 commit 14ff57e

File tree

3 files changed

+67
-10
lines changed

3 files changed

+67
-10
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1515
from roboflow.util.general import write_line
1616

17-
__version__ = "1.1.23"
17+
__version__ = "1.1.24"
1818

1919

2020
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/version.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,13 @@ def live_plot(epochs, mAP, loss, title=""):
419419
return self.model
420420

421421
# @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")])
422-
def deploy(self, model_type: str, model_path: str) -> None:
423-
"""Uploads provided weights file to Roboflow
422+
def deploy(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
423+
"""Uploads provided weights file to Roboflow.
424424
425425
Args:
426-
model_path (str): File path to model weights to be uploaded
426+
model_type (str): The type of the model to be deployed.
427+
model_path (str): File path to the model weights to be uploaded.
428+
filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
427429
"""
428430

429431
supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
@@ -432,7 +434,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
432434
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))
433435

434436
if "yolonas" in model_type:
435-
self.deploy_yolonas(model_type, model_path)
437+
self.deploy_yolonas(model_type, model_path, filename)
436438
return
437439

438440
if "yolov8" in model_type:
@@ -457,7 +459,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
457459
" Please install it with `pip install torch`"
458460
)
459461

460-
model = torch.load(os.path.join(model_path, "weights/best.pt"))
462+
model = torch.load(os.path.join(model_path, filename))
461463

462464
if isinstance(model["model"].names, list):
463465
class_names = model["model"].names
@@ -542,7 +544,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
542544

543545
self.upload_zip(model_type, model_path)
544546

545-
def deploy_yolonas(self, model_type: str, model_path: str) -> None:
547+
def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
546548
try:
547549
import torch
548550
except ImportError:
@@ -551,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
551553
" Please install it with `pip install torch`"
552554
)
553555

554-
model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
556+
model = torch.load(os.path.join(model_path, filename), map_location="cpu")
555557
class_names = model["processing_params"]["class_names"]
556558

557559
opt_path = os.path.join(model_path, "opt.yaml")
@@ -584,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
584586
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
585587
json.dump(model_artifacts, fp)
586588

587-
shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
589+
shutil.copy(os.path.join(model_path, filename), os.path.join(model_path, "state_dict.pt"))
588590

589591
list_files = [
590592
"results.json",
@@ -602,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
602604
compress_type=zipfile.ZIP_DEFLATED,
603605
)
604606
else:
605-
if file in ["model_artifacts.json", "best.pt"]:
607+
if file in ["model_artifacts.json", filename]:
606608
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))
607609

608610
self.upload_zip(model_type, model_path)

roboflow/roboflowpy.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def upload_image(args):
6868
)
6969

7070

71+
def upload_model(args):
72+
rf = roboflow.Roboflow(args.api_key)
73+
workspace = rf.workspace(args.workspace)
74+
project = workspace.project(args.project)
75+
version = project.version(args.version_number)
76+
print(args.model_type, args.model_path, args.filename)
77+
version.deploy(str(args.model_type), str(args.model_path), str(args.filename))
78+
79+
7180
def list_projects(args):
7281
rf = roboflow.Roboflow()
7382
workspace = rf.workspace(args.workspace)
@@ -145,6 +154,7 @@ def _argparser():
145154
_add_infer_parser(subparsers)
146155
_add_projects_parser(subparsers)
147156
_add_workspaces_parser(subparsers)
157+
_add_upload_model_parser(subparsers)
148158
return parser
149159

150160

@@ -347,6 +357,51 @@ def _add_infer_parser(subparsers):
347357
infer_parser.set_defaults(func=infer)
348358

349359

360+
def _add_upload_model_parser(subparsers):
361+
upload_model_parser = subparsers.add_parser(
362+
"upload_model",
363+
help="Upload a trained model to Roboflow",
364+
)
365+
upload_model_parser.add_argument(
366+
"-a",
367+
dest="api_key",
368+
help="api_key",
369+
)
370+
upload_model_parser.add_argument(
371+
"-w",
372+
dest="workspace",
373+
help="specify a workspace url or id (will use default workspace if not specified)",
374+
)
375+
upload_model_parser.add_argument(
376+
"-p",
377+
dest="project",
378+
help="project_id to upload the model into",
379+
)
380+
upload_model_parser.add_argument(
381+
"-v",
382+
dest="version_number",
383+
type=int,
384+
help="version number to upload the model to",
385+
)
386+
upload_model_parser.add_argument(
387+
"-t",
388+
dest="model_type",
389+
help="type of the model (e.g., yolov8, yolov5)",
390+
)
391+
upload_model_parser.add_argument(
392+
"-m",
393+
dest="model_path",
394+
help="path to the trained model file",
395+
)
396+
upload_model_parser.add_argument(
397+
"-f",
398+
dest="filename",
399+
default="weights/best.pt",
400+
help="name of the model file",
401+
)
402+
upload_model_parser.set_defaults(func=upload_model)
403+
404+
350405
def _add_login_parser(subparsers):
351406
login_parser = subparsers.add_parser("login", help="Log in to Roboflow")
352407
login_parser.set_defaults(func=login)

0 commit comments

Comments
 (0)