|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import shutil |
| 4 | +import asyncio |
| 5 | +import tempfile |
| 6 | +import aiofiles |
| 7 | +from typing import List |
| 8 | +from asyncio.queues import Queue |
| 9 | + |
| 10 | + |
| 11 | +class SubprocessManager(object): |
| 12 | + def __init__(self, env=None, cwd=None): |
| 13 | + if env is None: |
| 14 | + env = os.environ.copy() |
| 15 | + self.env = env |
| 16 | + |
| 17 | + if cwd is None: |
| 18 | + cwd = os.getcwd() |
| 19 | + self.cwd = cwd |
| 20 | + |
| 21 | + self.process = None |
| 22 | + self.run_command_called = False |
| 23 | + |
| 24 | + async def get_logs(self, stream="stdout"): |
| 25 | + if self.run_command_called is False: |
| 26 | + raise ValueError("No command run yet to get the logs for...") |
| 27 | + if stream == "stdout": |
| 28 | + stdout_task = asyncio.create_task(self.consume_queue(self.stdout_queue)) |
| 29 | + await stdout_task |
| 30 | + elif stream == "stderr": |
| 31 | + stderr_task = asyncio.create_task(self.consume_queue(self.stderr_queue)) |
| 32 | + await stderr_task |
| 33 | + else: |
| 34 | + raise ValueError( |
| 35 | + f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}" |
| 36 | + ) |
| 37 | + |
| 38 | + async def stream_logs_to_queue(self, logfile, queue, process): |
| 39 | + async with aiofiles.open(logfile, "r") as f: |
| 40 | + while True: |
| 41 | + if process.returncode is None: |
| 42 | + # process is still running |
| 43 | + line = await f.readline() |
| 44 | + if not line: |
| 45 | + continue |
| 46 | + await queue.put(line.strip()) |
| 47 | + elif process.returncode == 0: |
| 48 | + # insert an indicator that no more items |
| 49 | + # will be inserted into the queue |
| 50 | + await queue.put(None) |
| 51 | + break |
| 52 | + elif process.returncode != 0: |
| 53 | + # insert an indicator that no more items |
| 54 | + # will be inserted into the queue |
| 55 | + await queue.put(None) |
| 56 | + raise Exception("Ran into an issue...") |
| 57 | + |
| 58 | + async def consume_queue(self, queue: Queue): |
| 59 | + while True: |
| 60 | + item = await queue.get() |
| 61 | + # break out of loop when we get the `indicator` |
| 62 | + if item is None: |
| 63 | + break |
| 64 | + print(item) |
| 65 | + queue.task_done() |
| 66 | + |
| 67 | + async def run_command(self, command: List[str]): |
| 68 | + self.temp_dir = tempfile.mkdtemp() |
| 69 | + stdout_logfile = os.path.join(self.temp_dir, "stdout.log") |
| 70 | + stderr_logfile = os.path.join(self.temp_dir, "stderr.log") |
| 71 | + |
| 72 | + self.stdout_queue = Queue() |
| 73 | + self.stderr_queue = Queue() |
| 74 | + |
| 75 | + try: |
| 76 | + # returns when subprocess has been started, not |
| 77 | + # when it is finished... |
| 78 | + self.process = await asyncio.create_subprocess_exec( |
| 79 | + *command, |
| 80 | + cwd=self.cwd, |
| 81 | + env=self.env, |
| 82 | + stdout=await aiofiles.open(stdout_logfile, "w"), |
| 83 | + stderr=await aiofiles.open(stderr_logfile, "w"), |
| 84 | + ) |
| 85 | + |
| 86 | + self.stdout_task = asyncio.create_task( |
| 87 | + self.stream_logs_to_queue( |
| 88 | + stdout_logfile, self.stdout_queue, self.process |
| 89 | + ) |
| 90 | + ) |
| 91 | + self.stderr_task = asyncio.create_task( |
| 92 | + self.stream_logs_to_queue( |
| 93 | + stderr_logfile, self.stderr_queue, self.process |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + self.run_command_called = True |
| 98 | + return self.process |
| 99 | + except Exception as e: |
| 100 | + print(f"Error starting subprocess: {e}") |
| 101 | + # Clean up temp files if process fails to start |
| 102 | + shutil.rmtree(self.temp_dir, ignore_errors=True) |
| 103 | + |
| 104 | + |
| 105 | +async def main(): |
| 106 | + flow_file = "../try.py" |
| 107 | + from metaflow.cli import start |
| 108 | + from metaflow.click_api import MetaflowAPI |
| 109 | + |
| 110 | + api = MetaflowAPI.from_cli(flow_file, start) |
| 111 | + command = api().run(alpha=5) |
| 112 | + cmd = [sys.executable, *command.split()] |
| 113 | + |
| 114 | + spm = SubprocessManager() |
| 115 | + process = await spm.run_command(cmd) |
| 116 | + # await process.wait() |
| 117 | + # print(process.returncode) |
| 118 | + print("will print logs after 15 secs, flow has ended by then...") |
| 119 | + await asyncio.sleep(15) |
| 120 | + print("done waiting...") |
| 121 | + await spm.get_logs(stream="stdout") |
| 122 | + |
| 123 | + |
| 124 | +if __name__ == "__main__": |
| 125 | + asyncio.run(main()) |
0 commit comments