Skip to content

Commit 66c3d30

Browse files
committed
model search with model upload to best run
1 parent 2c4f7db commit 66c3d30

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

model_search_upload.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,7 @@ def uploat_to_mlflow(temp_dir, **context):
6969
print(f"Experiment {experiment_name} was not found, creating new")
7070
experiment_id = client.create_experiment(experiment_name)
7171

72-
run = client.create_run(experiment_id)
73-
print(f"Uploading to experiment {experiment_name}/{experiment_id}/{run.info.run_id}")
74-
75-
print("Uploading model")
76-
client.log_artifact(
77-
run_id=run.info.run_id,
78-
local_path=os.path.join(temp_dir, 'model.dat'),
79-
artifact_path="model",
80-
)
72+
print(f"Uploading to experiment {experiment_name}/{experiment_id}")
8173

8274
print("Uploading model search results")
8375
df = pd.read_csv(os.path.join(temp_dir, 'pd.csv'), index_col=0)
@@ -86,7 +78,7 @@ def uploat_to_mlflow(temp_dir, **context):
8678
metrics=['mean_test_score', 'mean_fit_time']
8779

8880
for i, p in enumerate(dct['params'].values()):
89-
with mlflow.start_run(experiment_id=experiment_id):
81+
with mlflow.start_run(experiment_id=experiment_id) as run:
9082
p = json.loads(p.replace('\'', '"'))
9183
for parname, parvalue in p.items():
9284
mlflow.log_param(key=parname, value=parvalue)
@@ -98,6 +90,15 @@ def uploat_to_mlflow(temp_dir, **context):
9890
print(f"Logging metric {m} {dct[m][i]}")
9991
mlflow.log_metric(key=m, value=dct[m][i])
10092

93+
if dct['rank_test_score'][i]==1:
94+
print('This is the best model')
95+
print("Uploading model to run: ", run.info.run_id)
96+
mlflow.log_artifact(
97+
local_path=os.path.join(temp_dir, 'model.dat'),
98+
artifact_path="model",
99+
)
100+
101+
101102
#clean up
102103
shutil.rmtree(temp_dir)
103104

0 commit comments

Comments
 (0)