Skip to content

Commit d531b2e

Browse files
Added get_latest_model method (#231)
1 parent 5887633 commit d531b2e

File tree

3 files changed

+37
-30
lines changed

3 files changed

+37
-30
lines changed

diabetes_regression/evaluate/evaluate_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from azureml.core import Run
2727
import argparse
2828
import traceback
29-
from util.model_helper import get_model_by_tag
29+
from util.model_helper import get_latest_model
3030

3131
run = Run.get_context()
3232

@@ -45,7 +45,7 @@
4545
# sources_dir = 'diabetes_regression'
4646
# path_to_util = os.path.join(".", sources_dir, "util")
4747
# sys.path.append(os.path.abspath(path_to_util)) # NOQA: E402
48-
# from model_helper import get_model_by_tag
48+
# from model_helper import get_latest_model
4949
# workspace_name = os.environ.get("WORKSPACE_NAME")
5050
# experiment_name = os.environ.get("EXPERIMENT_NAME")
5151
# resource_group = os.environ.get("RESOURCE_GROUP")
@@ -108,7 +108,7 @@
108108
firstRegistration = False
109109
tag_name = 'experiment_name'
110110

111-
model = get_model_by_tag(
111+
model = get_latest_model(
112112
model_name, tag_name, exp.name, ws)
113113

114114
if (model is not None):

diabetes_regression/util/model_helper.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@ def get_current_workspace() -> Workspace:
2222
return experiment.workspace
2323

2424

25-
def get_model_by_tag(
25+
def get_latest_model(
2626
model_name: str,
27-
tag_name: str,
28-
tag_value: str,
27+
tag_name: str = None,
28+
tag_value: str = None,
2929
aml_workspace: Workspace = None
3030
) -> AMLModel:
3131
"""
3232
Retrieves and returns the latest model from the workspace
33-
by its name and tag.
33+
by its name and (optional) tag.
3434
3535
Parameters:
3636
aml_workspace (Workspace): aml.core Workspace that the model lives.
3737
model_name (str): name of the model we are looking for
38-
tag (str): the tag value the model was registered under.
38+
(optional) tag (str): the tag value & name the model was registered under.
3939
4040
Return:
4141
A single aml model from the workspace that matches the name and tag.
@@ -44,37 +44,44 @@ def get_model_by_tag(
4444
# Validate params. cannot be None.
4545
if model_name is None:
4646
raise ValueError("model_name[:str] is required")
47-
if tag_name is None:
48-
raise ValueError("tag_name[:str] is required")
49-
if tag_value is None:
50-
raise ValueError("tag[:str] is required")
47+
5148
if aml_workspace is None:
49+
print("No workspace defined - using current experiment workspace.")
5250
aml_workspace = get_current_workspace()
5351

54-
# get model by tag.
55-
model_list = AMLModel.list(
56-
aml_workspace, name=model_name,
57-
tags=[[tag_name, tag_value]], latest=True
58-
)
52+
model_list = None
53+
tag_ext = ""
54+
55+
# Get lastest model
56+
# True: by name and tags
57+
if tag_name is not None and tag_value is not None:
58+
model_list = AMLModel.list(
59+
aml_workspace, name=model_name,
60+
tags=[[tag_name, tag_value]], latest=True
61+
)
62+
tag_ext = f"tag_name: {tag_name}, tag_value: {tag_value}."
63+
# False: Only by name
64+
else:
65+
model_list = AMLModel.list(
66+
aml_workspace, name=model_name, latest=True)
5967

6068
# latest should only return 1 model, but if it does,
6169
# then maybe sdk or source code changed.
62-
should_not_happen = ("Found more than one model "
63-
"for the latest with {{tag_name: {tag_name},"
64-
"tag_value: {tag_value}. "
65-
"Models found: {model_list}}}")\
66-
.format(tag_name=tag_name, tag_value=tag_value,
67-
model_list=model_list)
68-
no_model_found = ("No Model found with {{tag_name: {tag_name} ,"
69-
"tag_value: {tag_value}.}}")\
70-
.format(tag_name=tag_name, tag_value=tag_value)
70+
71+
# define the error messages
72+
too_many_model_message = ("Found more than one latest model. "
73+
f"Models found: {model_list}. "
74+
f"{tag_ext}")
75+
76+
no_model_found_message = (f"No Model found with name: {model_name}. "
77+
f"{tag_ext}")
7178

7279
if len(model_list) > 1:
73-
raise ValueError(should_not_happen)
80+
raise ValueError(too_many_model_message)
7481
if len(model_list) == 1:
7582
return model_list[0]
7683
else:
77-
print(no_model_found)
84+
print(no_model_found_message)
7885
return None
7986
except Exception:
8087
raise

ml_service/pipelines/diabetes_regression_verify_train_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from azureml.core import Run, Experiment, Workspace
55
from ml_service.util.env_variables import Env
6-
from diabetes_regression.util.model_helper import get_model_by_tag
6+
from diabetes_regression.util.model_helper import get_latest_model
77

88

99
def main():
@@ -53,7 +53,7 @@ def main():
5353

5454
try:
5555
tag_name = 'BuildId'
56-
model = get_model_by_tag(
56+
model = get_latest_model(
5757
model_name, tag_name, build_id, exp.workspace)
5858
if (model is not None):
5959
print("Model was registered for this build.")

0 commit comments

Comments
 (0)