mlflow_stagein.py
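"""Stage in a model from the MLflow model repository.

Downloads a model artifact from a given MLflow run into a local scratch
directory, then copies it over SFTP to `location` on the SSH host described
by the DAG params.
"""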
import os
import tempfile

import pendulum
from airflow.decorators import dag, task
from airflow.models import Variable
from airflow.models.param import Param
from airflow.operators.python import PythonOperator

from decors import get_connection, remove, setup
from utils import file_exist, copy_streams
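# `decors` and `utils` are project-local modules. Their contracts, inferred
# from how they are used below (assumed, not verified against their source):
#   setup(**context) -> str              creates a temporary Airflow SSH
#                                        connection and returns its conn_id
#   get_connection(conn_id, **context)   returns an SSHHook for that conn_id
#   remove(conn_id)                      deletes the temporary connection
#   get_mlflow_client()                  returns a configured MLflow client
#   file_exist(sftp, name) -> bool       True if `name` exists on the SFTP host
#   copy_streams(inp, outp)              chunked copy between file-like objects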
@dag(
    schedule=None,
    start_date=pendulum.today("UTC"),
    tags=["example", "model repo"],
    params={
        "location": Param(
            default="/tmp/", type="string", description="Target directory to copy the model into"
        ),
        "vault_id": Param(default="", type="string", description="Vault id used to resolve SSH credentials"),
        "host": Param(default="", type="string", description="SSH host to stage the model onto"),
        "port": Param(default=22, type="integer", description="SSH port"),
        "login": Param(default="", type="string", description="SSH user name"),
        "mlflow_runid": Param(
            default="", type="string", description="MLflow run id from which the model should be staged in"
        ),
        "mlflow_modelpath": Param(
            default="model/model.pkl", type="string", description="Artifact path of the model within the run"
        ),
    },
)
def model_stagein():
    @task()
    def copy_model(connection_id, **context):
        import shutil

        from utils import get_mlflow_client

        params = context["params"]
        location = params["location"]
        run_id = params["mlflow_runid"]
        model_path = params["mlflow_modelpath"]

        # Download the model artifact into a scratch directory under working_dir.
        work_dir = Variable.get("working_dir", default_var="/tmp/")
        temp_dir = tempfile.mkdtemp(dir=work_dir)
        client = get_mlflow_client()
        ret = client.download_artifacts(run_id=run_id, path=model_path, dst_path=temp_dir)
        if ret:
            print("Model downloaded:", ret)

        # Upload the downloaded file to the target host over SFTP.
        ssh_hook = get_connection(conn_id=connection_id, **context)
        clt = ssh_hook.get_conn()
        sftp_client = clt.open_sftp()  # reuse the SSH client instead of opening a second connection
        with open(ret, "rb") as sr:
            target_name = os.path.join(location, os.path.basename(ret))
            print(f"Uploading: {ret} --> {target_name}")
            if file_exist(sftp=sftp_client, name=target_name):
                print(target_name, "exists. Overwriting.")
                # Refresh the remote file's timestamp before overwriting it.
                clt.exec_command(command=f"touch {target_name}")
            with sftp_client.open(target_name, "wb") as tr:
                tr.set_pipelined(pipelined=True)
                copy_streams(inp=sr, outp=tr)

        shutil.rmtree(path=temp_dir)
        return target_name
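
    # Wire the flow: create a temporary SSH connection, copy the model, then
    # drop the connection. `setup_task.output["return_value"]` forwards the
    # generated conn_id to the other tasks via XCom.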
    setup_task = PythonOperator(python_callable=setup, task_id="setup_connection")
    a_id = setup_task.output["return_value"]

    cpy = copy_model(connection_id=a_id)

    cleanup_task = PythonOperator(
        python_callable=remove, op_kwargs={"conn_id": a_id}, task_id="cleanup"
    )

    setup_task >> cpy >> cleanup_task


dag = model_stagein()
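
# Minimal local smoke test: a sketch assuming Airflow 2.5+ (for `dag.test()`),
# a reachable MLflow tracking server, and valid SSH params in the conf; the
# run id below is a placeholder, not a real run.
if __name__ == "__main__":
    dag.test(
        run_conf={
            "location": "/tmp/",
            "mlflow_runid": "<run-id>",
            "mlflow_modelpath": "model/model.pkl",
        }
    )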