Skip to content

Commit 6f725aa

Browse files
committed
subprocess manager
1 parent a3919a0 commit 6f725aa

File tree

2 files changed

+152
-27
lines changed

2 files changed

+152
-27
lines changed

metaflow/metaflow_runner.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,67 @@
11
import os
22
import sys
3-
import time
3+
import shutil
4+
import asyncio
45
import tempfile
5-
import subprocess
6+
import aiofiles
67
from typing import Dict
78
from metaflow import Run
89
from metaflow.cli import start
910
from metaflow.click_api import MetaflowAPI
11+
from metaflow.subprocess_manager import SubprocessManager
1012

1113

12-
def cli_runner(command: str, env_vars: Dict):
13-
process = subprocess.Popen(
14-
[sys.executable, *command.split()],
15-
stdout=subprocess.PIPE,
16-
stderr=subprocess.PIPE,
17-
env=env_vars,
18-
)
19-
return process
20-
21-
22-
def read_from_file_when_ready(file_pointer):
23-
content = file_pointer.read().decode()
24-
while not content:
25-
time.sleep(0.1)
26-
content = file_pointer.read().decode()
27-
return content
14+
async def read_from_file_when_ready(file_path, poll_interval=0.1, timeout=None):
    """Return the contents of ``file_path`` once the file is non-empty.

    The file is opened once and re-read every ``poll_interval`` seconds
    until a read returns data.

    Parameters
    ----------
    file_path : str
        Path of an existing file that another process is expected to write.
    poll_interval : float, default 0.1
        Seconds to sleep between reads while the file is still empty.
    timeout : float, optional
        Maximum number of seconds to wait. ``None`` (the default) waits
        indefinitely, matching the previous behaviour.

    Returns
    -------
    str
        The first non-empty chunk read from the file.

    Raises
    ------
    asyncio.TimeoutError
        If ``timeout`` is given and the file is still empty when it expires.
    """
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout

    async with aiofiles.open(file_path, "r") as file_pointer:
        content = await file_pointer.read()
        while not content:
            # Without a deadline this loop never ends if the writer dies
            # before producing any output, so callers can opt in to one.
            if deadline is not None and loop.time() >= deadline:
                raise asyncio.TimeoutError(
                    f"Timed out waiting for content in {file_path}"
                )
            await asyncio.sleep(poll_interval)
            content = await file_pointer.read()
        return content
2821

2922

3023
class Runner(object):
    """Launch a flow in a subprocess and return the resulting ``Run``.

    The command line is built through ``MetaflowAPI`` and executed by a
    ``SubprocessManager``. The class is also a context manager: leaving the
    ``with`` block removes the subprocess manager's temporary log directory.

    Parameters
    ----------
    flow_file : str
        Path to the flow file handed to ``MetaflowAPI.from_cli``.
    env : Dict, optional
        Extra environment variables layered on top of ``os.environ`` for the
        subprocess. ``None`` or an empty mapping means "no overrides".
    **kwargs
        Forwarded to the API object returned by ``MetaflowAPI.from_cli``.
    """

    def __init__(
        self,
        flow_file: str,
        env: Dict = None,
        **kwargs,
    ):
        self.flow_file = flow_file
        self.env_vars = self._merge_env(env)
        self.spm = SubprocessManager(env=self.env_vars)
        self.api = MetaflowAPI.from_cli(self.flow_file, start)
        self.runner = self.api(**kwargs).run

    @staticmethod
    def _merge_env(overrides):
        """Return a copy of ``os.environ`` updated with ``overrides``."""
        # dict.update() returns None, so the merge must not be chained onto
        # the copy -- doing so left the subprocess environment set to None.
        merged = os.environ.copy()
        if overrides:
            merged.update(overrides)
        return merged

    def __enter__(self):
        return self

    async def tail_logs(self, stream="stdout"):
        """Print the subprocess logs for ``stream`` ("stdout" or "stderr")."""
        await self.spm.get_logs(stream)

    async def run(self, blocking: bool = False, **kwargs):
        """Start the flow and return the ``Run`` object it created.

        Parameters
        ----------
        blocking : bool, default False
            When True, wait for the subprocess to exit before reading the
            flow name and run id back.
        **kwargs
            Forwarded to the API's ``run`` command builder.

        Returns
        -------
        Run
            Built from the ``flow_name/run_id`` pathspec read from the two
            temporary files the subprocess is asked to write.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            tfp_flow = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
            tfp_run_id = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
            # Only the paths are needed; close our handles right away so the
            # descriptors are not leaked for the lifetime of the run.
            tfp_flow.close()
            tfp_run_id.close()

            command = self.runner(
                run_id_file=tfp_run_id.name, flow_name_file=tfp_flow.name, **kwargs
            )

            process = await self.spm.run_command([sys.executable, *command.split()])

            if blocking:
                await process.wait()

            flow_name = await read_from_file_when_ready(tfp_flow.name)
            run_id = await read_from_file_when_ready(tfp_run_id.name)

            pathspec_components = (flow_name, run_id)
            run_object = Run("/".join(pathspec_components), _namespace_check=False)

            # NOTE(review): this instance attribute shadows the ``run``
            # method, so ``run()`` cannot be called a second time on the same
            # Runner. Kept as-is because callers may read ``runner.run``;
            # consider renaming the attribute in a follow-up.
            self.run = run_object

            return run_object

    def __exit__(self, exc_type, exc_value, traceback):
        # ``temp_dir`` only exists on the manager once a command has been
        # started; exiting the context without running must not raise.
        temp_dir = getattr(self.spm, "temp_dir", None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)

metaflow/subprocess_manager.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import os
2+
import sys
3+
import shutil
4+
import asyncio
5+
import tempfile
6+
import aiofiles
7+
from typing import List
8+
from asyncio.queues import Queue
9+
10+
11+
class SubprocessManager(object):
    """Run one command as an asyncio subprocess and expose its logs.

    The child's stdout and stderr are redirected to files in a temporary
    directory (``self.temp_dir``). Two background tasks read those files and
    push each stripped line onto ``self.stdout_queue`` / ``self.stderr_queue``,
    ending each queue with a ``None`` sentinel once the process has exited.

    Parameters
    ----------
    env : dict, optional
        Environment for the child; defaults to a copy of ``os.environ``.
    cwd : str, optional
        Working directory for the child; defaults to ``os.getcwd()``.
    """

    def __init__(self, env=None, cwd=None):
        if env is None:
            env = os.environ.copy()
        self.env = env

        if cwd is None:
            cwd = os.getcwd()
        self.cwd = cwd

        self.process = None
        self.run_command_called = False

    async def get_logs(self, stream="stdout"):
        """Print queued log lines for ``stream`` until the sentinel arrives.

        Parameters
        ----------
        stream : str, default "stdout"
            Either "stdout" or "stderr".

        Raises
        ------
        ValueError
            If no command has been started yet, or ``stream`` is not one of
            the two supported names.

        Notes
        -----
        This consumes the queue, including its end-of-stream sentinel.
        """
        if self.run_command_called is False:
            raise ValueError("No command run yet to get the logs for...")
        if stream == "stdout":
            queue = self.stdout_queue
        elif stream == "stderr":
            queue = self.stderr_queue
        else:
            raise ValueError(
                f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
            )
        # Awaiting the coroutine directly is equivalent to wrapping it in a
        # task that is awaited immediately.
        await self.consume_queue(queue)

    async def stream_logs_to_queue(self, logfile, queue, process):
        """Copy lines from ``logfile`` into ``queue`` while ``process`` runs.

        Each line is stripped before being queued. After the process has
        exited and the rest of the file has been read, ``None`` is queued to
        signal that no more items will follow.

        Raises
        ------
        Exception
            If the process finished with a non-zero return code (raised
            after the sentinel has been queued).
        """
        async with aiofiles.open(logfile, "r") as f:
            while True:
                # Sample the exit state *before* reading: if the process had
                # already exited, an empty read means the file is fully
                # drained and no trailing output is lost.
                exited = process.returncode is not None
                line = await f.readline()
                if line:
                    await queue.put(line.strip())
                    continue
                if exited:
                    break
                # Nothing new yet; back off briefly instead of spinning.
                await asyncio.sleep(0.1)

        # insert an indicator that no more items
        # will be inserted into the queue
        await queue.put(None)
        if process.returncode != 0:
            raise Exception("Ran into an issue...")

    async def consume_queue(self, queue: Queue):
        """Print items from ``queue`` until the ``None`` sentinel is received."""
        while True:
            item = await queue.get()
            # Mark every retrieved item (sentinel included) as processed so
            # that ``queue.join()`` cannot block on the final entry.
            queue.task_done()
            # break out of loop when we get the `indicator`
            if item is None:
                break
            print(item)

    async def run_command(self, command: List[str]):
        """Start ``command`` and begin streaming its logs in the background.

        Parameters
        ----------
        command : List[str]
            Program and arguments, passed to
            ``asyncio.create_subprocess_exec``.

        Returns
        -------
        asyncio.subprocess.Process
            The started process; this returns as soon as the process has
            been launched, not when it finishes.

        Raises
        ------
        Exception
            Whatever starting the subprocess raised. The temporary log
            directory is removed before the error is re-raised.
        """
        self.temp_dir = tempfile.mkdtemp()
        stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
        stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

        self.stdout_queue = Queue()
        self.stderr_queue = Queue()

        try:
            # The child gets its own copies of these descriptors, so the
            # parent's handles can be closed as soon as it has been spawned.
            with open(stdout_logfile, "w") as out, open(stderr_logfile, "w") as err:
                # returns when subprocess has been started, not
                # when it is finished...
                self.process = await asyncio.create_subprocess_exec(
                    *command,
                    cwd=self.cwd,
                    env=self.env,
                    stdout=out,
                    stderr=err,
                )
        except Exception as e:
            print(f"Error starting subprocess: {e}")
            # Clean up temp files if process fails to start
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            # Propagate the failure: returning None here would only surface
            # later as an AttributeError in the caller.
            raise

        self.stdout_task = asyncio.create_task(
            self.stream_logs_to_queue(stdout_logfile, self.stdout_queue, self.process)
        )
        self.stderr_task = asyncio.create_task(
            self.stream_logs_to_queue(stderr_logfile, self.stderr_queue, self.process)
        )

        self.run_command_called = True
        return self.process
103+
104+
105+
async def main():
    """Manual smoke test: launch ``../try.py`` and print its stdout logs.

    Builds the run command through ``MetaflowAPI``, starts it with a
    ``SubprocessManager``, sleeps for 15 seconds and then prints whatever
    was queued from the subprocess's stdout.
    """
    from metaflow.cli import start
    from metaflow.click_api import MetaflowAPI

    flow_file = "../try.py"
    flow_api = MetaflowAPI.from_cli(flow_file, start)
    run_cli = flow_api().run(alpha=5)
    argv = [sys.executable] + run_cli.split()

    manager = SubprocessManager()
    await manager.run_command(argv)

    # The fixed delay stands in for waiting on the process to finish.
    print("will print logs after 15 secs, flow has ended by then...")
    await asyncio.sleep(15)
    print("done waiting...")

    await manager.get_logs(stream="stdout")


if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)