23
23
ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
24
24
POSSIBILITY OF SUCH DAMAGE.
25
25
"""
26
+ import json
26
27
import os
27
28
import sys
28
29
import argparse
@@ -69,8 +70,9 @@ def main():
69
70
"--model_name" ,
70
71
type = str ,
71
72
help = "Name of the Model" ,
72
- default = "sklearn_regression_model .pkl" ,
73
+ default = "diabetes_model .pkl" ,
73
74
)
75
+
74
76
parser .add_argument (
75
77
"--step_input" ,
76
78
type = str ,
@@ -85,40 +87,58 @@ def main():
85
87
model_name = args .model_name
86
88
model_path = args .step_input
87
89
90
+ print ("Getting registration parameters" )
91
+
92
+ # Load the registration parameters from the parameters file
93
+ with open ("parameters.json" ) as f :
94
+ pars = json .load (f )
95
+ try :
96
+ register_args = pars ["registration" ]
97
+ except KeyError :
98
+ print ("Could not load registration values from file" )
99
+ register_args = {"tags" : []}
100
+
101
+ model_tags = {}
102
+ for tag in register_args ["tags" ]:
103
+ try :
104
+ mtag = run .parent .get_metrics ()[tag ]
105
+ model_tags [tag ] = mtag
106
+ except KeyError :
107
+ print (f"Could not find { tag } metric on parent run." )
108
+
88
109
# load the model
89
110
print ("Loading model from " + model_path )
90
111
model_file = os .path .join (model_path , model_name )
91
112
model = joblib .load (model_file )
92
- model_mse = run .parent .get_metrics ()["mse" ]
93
113
parent_tags = run .parent .get_tags ()
94
114
try :
95
115
build_id = parent_tags ["BuildId" ]
96
116
except KeyError :
97
117
build_id = None
98
118
print ("BuildId tag not found on parent run." )
99
- print ("Tags present: {parent_tags}" )
119
+ print (f "Tags present: { parent_tags } " )
100
120
try :
101
121
build_uri = parent_tags ["BuildUri" ]
102
122
except KeyError :
103
123
build_uri = None
104
124
print ("BuildUri tag not found on parent run." )
105
- print ("Tags present: {parent_tags}" )
125
+ print (f "Tags present: { parent_tags } " )
106
126
107
127
if (model is not None ):
108
128
dataset_id = parent_tags ["dataset_id" ]
109
129
if (build_id is None ):
110
130
register_aml_model (
111
131
model_file ,
112
132
model_name ,
113
- model_mse ,
133
+ model_tags ,
114
134
exp ,
115
135
run_id ,
116
136
dataset_id )
117
137
elif (build_uri is None ):
118
138
register_aml_model (
119
139
model_file ,
120
140
model_name ,
121
- model_mse ,
141
+ model_tags ,
122
142
exp ,
123
143
run_id ,
124
144
dataset_id ,
@@ -127,7 +147,7 @@ def main():
127
147
register_aml_model (
128
148
model_file ,
129
149
model_name ,
130
- model_mse ,
150
+ model_tags ,
131
151
exp ,
132
152
run_id ,
133
153
dataset_id ,
@@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id):
152
172
def register_aml_model (
153
173
model_path ,
154
174
model_name ,
155
- model_mse ,
175
+ model_tags ,
156
176
exp ,
157
177
run_id ,
158
178
dataset_id ,
@@ -162,8 +182,8 @@ def register_aml_model(
162
182
try :
163
183
tagsValue = {"area" : "diabetes_regression" ,
164
184
"run_id" : run_id ,
165
- "experiment_name" : exp .name ,
166
- "mse" : model_mse }
185
+ "experiment_name" : exp .name }
186
+ tagsValue . update ( model_tags )
167
187
if (build_id != 'none' ):
168
188
model_already_registered (model_name , exp , run_id )
169
189
tagsValue ["BuildId" ] = build_id
0 commit comments