Skip to content

Commit

Permalink
Merge pull request #242 from roboflow/upload_weights_file_name
Browse files Browse the repository at this point in the history
specify model weights file name for uploading custom weights
  • Loading branch information
ryanjball authored Mar 14, 2024
2 parents f0609a9 + 775a5d3 commit 14ff57e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 10 deletions.
2 changes: 1 addition & 1 deletion roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from roboflow.models import CLIPModel, GazeModel # noqa: F401
from roboflow.util.general import write_line

__version__ = "1.1.23"
__version__ = "1.1.24"


def check_key(api_key, model, notebook, num_retries=0):
Expand Down
20 changes: 11 additions & 9 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,13 @@ def live_plot(epochs, mAP, loss, title=""):
return self.model

# @warn_for_wrong_dependencies_versions([("ultralytics", "==", "8.0.196")])
def deploy(self, model_type: str, model_path: str) -> None:
"""Uploads provided weights file to Roboflow
def deploy(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
"""Uploads provided weights file to Roboflow.
Args:
model_path (str): File path to model weights to be uploaded
model_type (str): The type of the model to be deployed.
model_path (str): File path to the model weights to be uploaded.
filename (str, optional): The name of the weights file. Defaults to "weights/best.pt".
"""

supported_models = ["yolov5", "yolov7-seg", "yolov8", "yolov9", "yolonas"]
Expand All @@ -432,7 +434,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
raise (ValueError(f"Model type {model_type} not supported. Supported models are" f" {supported_models}"))

if "yolonas" in model_type:
self.deploy_yolonas(model_type, model_path)
self.deploy_yolonas(model_type, model_path, filename)
return

if "yolov8" in model_type:
Expand All @@ -457,7 +459,7 @@ def deploy(self, model_type: str, model_path: str) -> None:
" Please install it with `pip install torch`"
)

model = torch.load(os.path.join(model_path, "weights/best.pt"))
model = torch.load(os.path.join(model_path, filename))

if isinstance(model["model"].names, list):
class_names = model["model"].names
Expand Down Expand Up @@ -542,7 +544,7 @@ def deploy(self, model_type: str, model_path: str) -> None:

self.upload_zip(model_type, model_path)

def deploy_yolonas(self, model_type: str, model_path: str) -> None:
def deploy_yolonas(self, model_type: str, model_path: str, filename: str = "weights/best.pt") -> None:
try:
import torch
except ImportError:
Expand All @@ -551,7 +553,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
" Please install it with `pip install torch`"
)

model = torch.load(os.path.join(model_path, "weights/best.pt"), map_location="cpu")
model = torch.load(os.path.join(model_path, filename), map_location="cpu")
class_names = model["processing_params"]["class_names"]

opt_path = os.path.join(model_path, "opt.yaml")
Expand Down Expand Up @@ -584,7 +586,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
with open(os.path.join(model_path, "model_artifacts.json"), "w") as fp:
json.dump(model_artifacts, fp)

shutil.copy(os.path.join(model_path, "weights/best.pt"), os.path.join(model_path, "state_dict.pt"))
shutil.copy(os.path.join(model_path, filename), os.path.join(model_path, "state_dict.pt"))

list_files = [
"results.json",
Expand All @@ -602,7 +604,7 @@ def deploy_yolonas(self, model_type: str, model_path: str) -> None:
compress_type=zipfile.ZIP_DEFLATED,
)
else:
if file in ["model_artifacts.json", "best.pt"]:
if file in ["model_artifacts.json", filename]:
raise (ValueError(f"File {file} not found. Please make sure to provide a" " valid model path."))

self.upload_zip(model_type, model_path)
Expand Down
55 changes: 55 additions & 0 deletions roboflow/roboflowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def upload_image(args):
)


def upload_model(args):
rf = roboflow.Roboflow(args.api_key)
workspace = rf.workspace(args.workspace)
project = workspace.project(args.project)
version = project.version(args.version_number)
print(args.model_type, args.model_path, args.filename)
version.deploy(str(args.model_type), str(args.model_path), str(args.filename))


def list_projects(args):
rf = roboflow.Roboflow()
workspace = rf.workspace(args.workspace)
Expand Down Expand Up @@ -145,6 +154,7 @@ def _argparser():
_add_infer_parser(subparsers)
_add_projects_parser(subparsers)
_add_workspaces_parser(subparsers)
_add_upload_model_parser(subparsers)
return parser


Expand Down Expand Up @@ -347,6 +357,51 @@ def _add_infer_parser(subparsers):
infer_parser.set_defaults(func=infer)


def _add_upload_model_parser(subparsers):
upload_model_parser = subparsers.add_parser(
"upload_model",
help="Upload a trained model to Roboflow",
)
upload_model_parser.add_argument(
"-a",
dest="api_key",
help="api_key",
)
upload_model_parser.add_argument(
"-w",
dest="workspace",
help="specify a workspace url or id (will use default workspace if not specified)",
)
upload_model_parser.add_argument(
"-p",
dest="project",
help="project_id to upload the model into",
)
upload_model_parser.add_argument(
"-v",
dest="version_number",
type=int,
help="version number to upload the model to",
)
upload_model_parser.add_argument(
"-t",
dest="model_type",
help="type of the model (e.g., yolov8, yolov5)",
)
upload_model_parser.add_argument(
"-m",
dest="model_path",
help="path to the trained model file",
)
upload_model_parser.add_argument(
"-f",
dest="filename",
default="weights/best.pt",
help="name of the model file",
)
upload_model_parser.set_defaults(func=upload_model)


def _add_login_parser(subparsers):
login_parser = subparsers.add_parser("login", help="Log in to Roboflow")
login_parser.set_defaults(func=login)
Expand Down

0 comments on commit 14ff57e

Please sign in to comment.