Skip to content

Commit bafe765

Browse files
committed
simpler subprocess manager
1 parent 63acedd commit bafe765

File tree

2 files changed

+69
-67
lines changed

2 files changed

+69
-67
lines changed

metaflow/cli.py

-1
Original file line numberDiff line numberDiff line change
def write_file(file_path, content):
    """Persist ``str(content)`` to *file_path*; silently do nothing when
    *file_path* is None."""
    if file_path is None:
        return
    # The context manager closes the handle even if write() raises.
    with open(file_path, "w") as f:
        f.write(str(content))

875874

876875
def before_run(obj, tags, decospecs):

metaflow/subprocess_manager.py

+69-66
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import os
22
import sys
3+
import time
34
import shutil
45
import asyncio
56
import tempfile
6-
import aiofiles
77
from typing import List
8-
from asyncio.queues import Queue
98

109

1110
class SubprocessManager(object):
@@ -20,87 +19,77 @@ def __init__(self, env=None, cwd=None):
2019

2120
self.process = None
2221
self.run_command_called = False
23-
24-
async def get_logs(self, stream="stdout"):
25-
if self.run_command_called is False:
26-
raise ValueError("No command run yet to get the logs for...")
27-
if stream == "stdout":
28-
stdout_task = asyncio.create_task(self.consume_queue(self.stdout_queue))
29-
await stdout_task
30-
elif stream == "stderr":
31-
stderr_task = asyncio.create_task(self.consume_queue(self.stderr_queue))
32-
await stderr_task
33-
else:
34-
raise ValueError(
35-
f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
36-
)
37-
38-
async def stream_logs_to_queue(self, logfile, queue, process):
39-
async with aiofiles.open(logfile, "r") as f:
40-
while True:
41-
if process.returncode is None:
42-
# process is still running
43-
line = await f.readline()
44-
if not line:
45-
continue
46-
await queue.put(line.strip())
47-
elif process.returncode == 0:
48-
# insert an indicator that no more items
49-
# will be inserted into the queue
50-
await queue.put(None)
51-
break
52-
elif process.returncode != 0:
53-
# insert an indicator that no more items
54-
# will be inserted into the queue
55-
await queue.put(None)
56-
raise Exception("Ran into an issue...")
57-
58-
async def consume_queue(self, queue: Queue):
59-
while True:
60-
item = await queue.get()
61-
# break out of loop when we get the `indicator`
62-
if item is None:
63-
break
64-
print(item)
65-
queue.task_done()
22+
self.log_files = {}
23+
self.process_dict = {}
6624

6725
async def run_command(self, command: List[str]):
    """Start *command* as a subprocess with stdout/stderr redirected to
    log files in a fresh temporary directory.

    Returns the started ``asyncio`` subprocess (it may still be running),
    or None if the subprocess failed to start. Records the log file paths
    in ``self.log_files`` and sets ``self.run_command_called``.
    """
    self.temp_dir = tempfile.mkdtemp()
    stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
    stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

    stdout_f = None
    stderr_f = None
    try:
        stdout_f = open(stdout_logfile, "w")
        stderr_f = open(stderr_logfile, "w")
        # returns when the subprocess has been started, not
        # when it is finished...
        self.process = await asyncio.create_subprocess_exec(
            *command,
            cwd=self.cwd,
            env=self.env,
            stdout=stdout_f,
            stderr=stderr_f,
        )

        self.log_files["stdout"] = stdout_logfile
        self.log_files["stderr"] = stderr_logfile

        self.run_command_called = True
        return self.process
    except Exception as e:
        print(f"Error starting subprocess: {e}")
        # Clean up the temp dir if the process fails to start.
        await self.cleanup()
    finally:
        # The child inherits the file descriptors at spawn time, so the
        # parent's copies can (and must) be closed here to avoid leaking
        # two fds per run_command call.
        for f in (stdout_f, stderr_f):
            if f is not None:
                f.close()
49+
50+
async def stream_logs(self, stream):
    """Print every line appended to the *stream* log file since the last
    call, then remember the new read offset in ``self.process_dict``.

    Raises ValueError when no command has been run or *stream* has no
    registered log file.
    """
    if not self.run_command_called:
        raise ValueError("No command run yet to get the logs for...")

    if stream not in self.log_files:
        raise ValueError(f"No log file found for {stream}")

    path = self.log_files[stream]

    with open(path, mode="r") as fh:
        # Resume from wherever the previous call stopped (0 on first call).
        fh.seek(self.process_dict.get(stream, 0))

        # iter(readline, "") yields lines until EOF, same as the manual
        # while/break loop.
        for raw_line in iter(fh.readline, ""):
            print(raw_line.strip())

        self.process_dict[stream] = fh.tell()
70+
71+
async def get_logs(self, stream="stdout", delay=0.1):
    """Continuously print new output from *stream* until the subprocess
    exits, polling its log file every *delay* seconds.

    Ends with one final drain so output written between the last poll and
    process exit is not lost, and so a process that has already finished
    still has its logs printed.
    """
    while self.process.returncode is None:
        await self.stream_logs(stream)
        await asyncio.sleep(delay)
    # Final drain: without this, anything written after the last poll —
    # or everything, when the process exited before get_logs was called —
    # is silently dropped.
    await self.stream_logs(stream)
75+
76+
async def cleanup(self):
    """Best-effort removal of the temporary log directory, if one was
    ever created by run_command."""
    if not hasattr(self, "temp_dir"):
        return
    shutil.rmtree(self.temp_dir, ignore_errors=True)
10379

80+
async def kill_process(self, timeout=5):
    """Terminate the managed subprocess, escalating to a hard kill if it
    does not exit within *timeout* seconds.

    Prints a notice (and does nothing else) when there is no process or
    the process has already terminated.
    """
    if self.process is None:
        print("No process to kill.")
        return
    if self.process.returncode is not None:
        print("Process has already terminated.")
        return
    # Polite request first (SIGTERM); give it *timeout* seconds to comply.
    self.process.terminate()
    try:
        await asyncio.wait_for(self.process.wait(), timeout)
    except asyncio.TimeoutError:
        # Did not exit in time — force it (SIGKILL).
        self.process.kill()
92+
10493

10594
async def main():
10695
flow_file = "../try.py"
@@ -114,11 +103,25 @@ async def main():
114103
spm = SubprocessManager()
115104
process = await spm.run_command(cmd)
116105
# await process.wait()
117-
# print(process.returncode)
118-
print("will print logs after 15 secs, flow has ended by then...")
119-
await asyncio.sleep(15)
120-
print("done waiting...")
121-
await spm.get_logs(stream="stdout")
106+
print(process.returncode)
107+
print(process)
108+
109+
# print("kill after 2 seconds...get logs upto the point of killing...")
110+
# await asyncio.wait([
111+
# asyncio.create_task(spm.get_logs(stream="stdout")),
112+
# asyncio.create_task(asyncio.sleep(2)),
113+
# ], return_when="FIRST_COMPLETED")
114+
# await spm.kill_process()
115+
# print("done...")
116+
117+
# print("will print logs after 15 secs, flow has ended by then...")
118+
# time.sleep(15)
119+
# print("done waiting...")
120+
# await spm.get_logs(stream="stdout")
121+
# await spm.cleanup()
122+
123+
# await spm.get_logs(stream="stdout")
124+
# await spm.cleanup()
122125

123126

124127
if __name__ == "__main__":

0 commit comments

Comments
 (0)