1
1
import os
2
2
import sys
3
- import shutil
4
3
import asyncio
5
4
import tempfile
6
5
import aiofiles
7
6
from typing import Dict
8
7
from metaflow import Run
9
8
from metaflow .cli import start
10
9
from metaflow .click_api import MetaflowAPI
11
- from metaflow .subprocess_manager import SubprocessManager
10
+ from metaflow .subprocess_manager import SubprocessManager , CommandManager
12
11
13
12
14
13
async def read_from_file_when_ready (file_path ):
@@ -20,6 +19,28 @@ async def read_from_file_when_ready(file_path):
20
19
return content
21
20
22
21
22
+ class ExecutingRun (object ):
23
+ def __init__ (self , command_obj : CommandManager , run_obj : Run ) -> None :
24
+ self .command_obj = command_obj
25
+ self .run_obj = run_obj
26
+
27
+ def __getattr__ (self , name : str ):
28
+ if hasattr (self .run_obj , name ):
29
+ run_attr = getattr (self .run_obj , name )
30
+ if callable (run_attr ):
31
+ return lambda * args , ** kwargs : run_attr (* args , ** kwargs )
32
+ else :
33
+ return run_attr
34
+ elif hasattr (self .command_obj , name ):
35
+ command_attr = getattr (self .command_obj , name )
36
+ if callable (command_attr ):
37
+ return lambda * args , ** kwargs : command_attr (* args , ** kwargs )
38
+ else :
39
+ return command_attr
40
+ else :
41
+ raise AttributeError (f"Invalid attribute { name } " )
42
+
43
+
23
44
class Runner (object ):
24
45
def __init__ (
25
46
self ,
@@ -29,17 +50,14 @@ def __init__(
29
50
):
30
51
self .flow_file = flow_file
31
52
self .env_vars = os .environ .copy ().update (env )
32
- self .spm = SubprocessManager (env = self . env_vars )
53
+ self .spm = SubprocessManager ()
33
54
self .api = MetaflowAPI .from_cli (self .flow_file , start )
34
55
self .runner = self .api (** kwargs ).run
35
56
36
- def __enter__ (self ):
57
+ async def __aenter__ (self ):
37
58
return self
38
59
39
- async def tail_logs (self , stream = "stdout" ):
40
- await self .spm .get_logs (stream )
41
-
42
- async def run (self , blocking : bool = False , ** kwargs ):
60
+ async def run (self , ** kwargs ):
43
61
with tempfile .TemporaryDirectory () as temp_dir :
44
62
tfp_flow = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
45
63
tfp_run_id = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
@@ -48,20 +66,18 @@ async def run(self, blocking: bool = False, **kwargs):
48
66
run_id_file = tfp_run_id .name , flow_name_file = tfp_flow .name , ** kwargs
49
67
)
50
68
51
- process = await self .spm .run_command ([ sys . executable , * command . split ()])
52
-
53
- if blocking :
54
- await process . wait ( )
69
+ command_id = await self .spm .run_command (
70
+ [ sys . executable , * command ], env = self . env_vars
71
+ )
72
+ command_obj = self . spm . get ( command_id )
55
73
56
74
flow_name = await read_from_file_when_ready (tfp_flow .name )
57
75
run_id = await read_from_file_when_ready (tfp_run_id .name )
58
76
59
77
pathspec_components = (flow_name , run_id )
60
78
run_object = Run ("/" .join (pathspec_components ), _namespace_check = False )
61
79
62
- self .run = run_object
63
-
64
- return run_object
80
+ return ExecutingRun (command_obj , run_object )
65
81
66
- def __exit__ (self , exc_type , exc_value , traceback ):
67
- shutil . rmtree ( self .spm .temp_dir , ignore_errors = True )
82
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
83
+ await self .spm .cleanup ( )
0 commit comments