Skip to content

Commit 352ebbe

Browse files
authored
Model registration tags come from parameters.json (#237)
1 parent 2d54311 commit 352ebbe

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

diabetes_regression/evaluate/evaluate_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
"--model_name",
8484
type=str,
8585
help="Name of the Model",
86-
default="sklearn_regression_model.pkl",
86+
default="diabetes_model.pkl",
8787
)
8888

8989
parser.add_argument(

diabetes_regression/parameters.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
"evaluation":
77
{
88

9+
},
10+
"registration":
11+
{
12+
"tags": ["mse"]
913
},
1014
"scoring":
1115
{

diabetes_regression/register/register_model.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
2424
POSSIBILITY OF SUCH DAMAGE.
2525
"""
26+
import json
2627
import os
2728
import sys
2829
import argparse
@@ -69,8 +70,9 @@ def main():
6970
"--model_name",
7071
type=str,
7172
help="Name of the Model",
72-
default="sklearn_regression_model.pkl",
73+
default="diabetes_model.pkl",
7374
)
75+
7476
parser.add_argument(
7577
"--step_input",
7678
type=str,
@@ -85,40 +87,58 @@ def main():
8587
model_name = args.model_name
8688
model_path = args.step_input
8789

90+
print("Getting registration parameters")
91+
92+
# Load the registration parameters from the parameters file
93+
with open("parameters.json") as f:
94+
pars = json.load(f)
95+
try:
96+
register_args = pars["registration"]
97+
except KeyError:
98+
print("Could not load registration values from file")
99+
register_args = {"tags": []}
100+
101+
model_tags = {}
102+
for tag in register_args["tags"]:
103+
try:
104+
mtag = run.parent.get_metrics()[tag]
105+
model_tags[tag] = mtag
106+
except KeyError:
107+
print(f"Could not find {tag} metric on parent run.")
108+
88109
# load the model
89110
print("Loading model from " + model_path)
90111
model_file = os.path.join(model_path, model_name)
91112
model = joblib.load(model_file)
92-
model_mse = run.parent.get_metrics()["mse"]
93113
parent_tags = run.parent.get_tags()
94114
try:
95115
build_id = parent_tags["BuildId"]
96116
except KeyError:
97117
build_id = None
98118
print("BuildId tag not found on parent run.")
99-
print("Tags present: {parent_tags}")
119+
print(f"Tags present: {parent_tags}")
100120
try:
101121
build_uri = parent_tags["BuildUri"]
102122
except KeyError:
103123
build_uri = None
104124
print("BuildUri tag not found on parent run.")
105-
print("Tags present: {parent_tags}")
125+
print(f"Tags present: {parent_tags}")
106126

107127
if (model is not None):
108128
dataset_id = parent_tags["dataset_id"]
109129
if (build_id is None):
110130
register_aml_model(
111131
model_file,
112132
model_name,
113-
model_mse,
133+
model_tags,
114134
exp,
115135
run_id,
116136
dataset_id)
117137
elif (build_uri is None):
118138
register_aml_model(
119139
model_file,
120140
model_name,
121-
model_mse,
141+
model_tags,
122142
exp,
123143
run_id,
124144
dataset_id,
@@ -127,7 +147,7 @@ def main():
127147
register_aml_model(
128148
model_file,
129149
model_name,
130-
model_mse,
150+
model_tags,
131151
exp,
132152
run_id,
133153
dataset_id,
@@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id):
152172
def register_aml_model(
153173
model_path,
154174
model_name,
155-
model_mse,
175+
model_tags,
156176
exp,
157177
run_id,
158178
dataset_id,
@@ -162,8 +182,8 @@ def register_aml_model(
162182
try:
163183
tagsValue = {"area": "diabetes_regression",
164184
"run_id": run_id,
165-
"experiment_name": exp.name,
166-
"mse": model_mse}
185+
"experiment_name": exp.name}
186+
tagsValue.update(model_tags)
167187
if (build_id != 'none'):
168188
model_already_registered(model_name, exp, run_id)
169189
tagsValue["BuildId"] = build_id

diabetes_regression/training/train_aml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def main():
5555
"--model_name",
5656
type=str,
5757
help="Name of the Model",
58-
default="sklearn_regression_model.pkl",
58+
default="diabetes_model.pkl",
5959
)
6060

6161
parser.add_argument(

0 commit comments

Comments
 (0)