|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import concurrent.futures
|
2 | 4 | import glob
|
3 | 5 | import json
|
|
10 | 12 |
|
11 | 13 | from roboflow.adapters import rfapi
|
12 | 14 | from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError
|
13 |
| -from roboflow.config import API_URL, CLIP_FEATURIZE_URL, DEMO_KEYS |
| 15 | +from roboflow.config import API_URL, APP_URL, CLIP_FEATURIZE_URL, DEMO_KEYS |
14 | 16 | from roboflow.core.project import Project
|
15 | 17 | from roboflow.util import folderparser
|
16 | 18 | from roboflow.util.active_learning_utils import check_box_size, clip_encode, count_comparisons
|
17 | 19 | from roboflow.util.image_utils import load_labelmap
|
| 20 | +from roboflow.util.model_processor import process |
18 | 21 | from roboflow.util.two_stage_utils import ocr_infer
|
19 | 22 |
|
20 | 23 |
|
@@ -566,6 +569,73 @@ def active_learning(
|
566 | 569 | prediction_results if type(raw_data_location) is not np.ndarray else prediction_results[-1]["predictions"]
|
567 | 570 | )
|
568 | 571 |
|
| 572 | + def deploy_model( |
| 573 | + self, |
| 574 | + model_type: str, |
| 575 | + model_path: str, |
| 576 | + project_ids: list[str], |
| 577 | + model_name: str, |
| 578 | + filename: str = "weights/best.pt", |
| 579 | + ): |
| 580 | + """Uploads provided weights file to Roboflow. |
| 581 | + Args: |
| 582 | + model_type (str): The type of the model to be deployed. |
| 583 | + model_path (str): File path to the model weights to be uploaded. |
| 584 | + project_ids (list[str]): List of project IDs to deploy the model to. |
| 585 | + filename (str, optional): The name of the weights file. Defaults to "weights/best.pt". |
| 586 | + """ |
| 587 | + |
| 588 | + if not project_ids: |
| 589 | + raise ValueError("At least one project ID must be provided") |
| 590 | + |
| 591 | + # Validate if provided project URLs belong to user's projects |
| 592 | + user_projects = set(project.split("/")[-1] for project in self.projects()) |
| 593 | + for project_id in project_ids: |
| 594 | + if project_id not in user_projects: |
| 595 | + raise ValueError(f"Project {project_id} is not accessible in this workspace") |
| 596 | + |
| 597 | + zip_file_name = process(model_type, model_path, filename) |
| 598 | + |
| 599 | + if zip_file_name is None: |
| 600 | + raise RuntimeError("Failed to process model") |
| 601 | + |
| 602 | + self._upload_zip(model_type, model_path, project_ids, model_name, zip_file_name) |
| 603 | + |
| 604 | + def _upload_zip( |
| 605 | + self, |
| 606 | + model_type: str, |
| 607 | + model_path: str, |
| 608 | + project_ids: list[str], |
| 609 | + model_name: str, |
| 610 | + model_file_name: str, |
| 611 | + ): |
| 612 | + # This endpoint returns a signed URL to upload the model |
| 613 | + res = requests.post( |
| 614 | + f"{API_URL}/{self.url}/models/prepareUpload?api_key={self.__api_key}&modelType={model_type}&modelName={model_name}&projectIds={','.join(project_ids)}&nocache=true" |
| 615 | + ) |
| 616 | + try: |
| 617 | + res.raise_for_status() |
| 618 | + except Exception as e: |
| 619 | + print(f"An error occured when getting the model deployment URL: {e}") |
| 620 | + return |
| 621 | + |
| 622 | + # Upload the model to the signed URL |
| 623 | + res = requests.put( |
| 624 | + res.json()["url"], |
| 625 | + data=open(os.path.join(model_path, model_file_name), "rb"), |
| 626 | + ) |
| 627 | + try: |
| 628 | + res.raise_for_status() |
| 629 | + |
| 630 | + for project_id in project_ids: |
| 631 | + print( |
| 632 | + f"View the status of your deployment for project {project_id} at:" |
| 633 | + f" {APP_URL}/{self.url}/{project_id}/models" |
| 634 | + ) |
| 635 | + |
| 636 | + except Exception as e: |
| 637 | + print(f"An error occured when uploading the model: {e}") |
| 638 | + |
569 | 639 | def __str__(self):
|
570 | 640 | projects = self.projects()
|
571 | 641 | json_value = {"name": self.name, "url": self.url, "projects": projects}
|
|
0 commit comments