Skip to content

Commit c681f71

Browse files
committed
suggested improvements
1 parent 9c95627 commit c681f71

File tree

1 file changed

+75
-45
lines changed

1 file changed

+75
-45
lines changed

metaflow/subprocess_manager.py

+75-45
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,112 @@
11
import os
22
import sys
3-
import time
43
import signal
54
import shutil
6-
import hashlib
75
import asyncio
86
import tempfile
9-
from typing import List
10-
11-
12-
def hash_command_invocation(command: List[str]):
13-
concatenated_string = "".join(command)
14-
current_time = str(time.time())
15-
concatenated_string += current_time
16-
hash_object = hashlib.sha256(concatenated_string.encode())
17-
return hash_object.hexdigest()
7+
from typing import List, Dict, Optional, Callable
188

199

2010
class LogReadTimeoutError(Exception):
11+
"""Exception raised when reading logs times out."""
12+
2113
pass
2214

2315

2416
class SubprocessManager(object):
17+
"""A manager for subprocesses."""
18+
2519
def __init__(self):
26-
self.commands = {}
20+
self.commands: Dict[int, CommandManager] = {}
2721

28-
async def __aenter__(self):
22+
async def __aenter__(self) -> "SubprocessManager":
2923
return self
3024

3125
async def __aexit__(self, exc_type, exc_value, traceback):
3226
await self.cleanup()
3327

34-
async def run_command(self, command: List[str], env=None, cwd=None):
35-
command_id = hash_command_invocation(command)
36-
self.commands[command_id] = CommandManager(command, env, cwd)
37-
await self.commands[command_id].run()
38-
return command_id
28+
async def run_command(
29+
self,
30+
command: List[str],
31+
env: Optional[Dict[str, str]] = None,
32+
cwd: Optional[str] = None,
33+
) -> int:
34+
"""Run a command asynchronously and return its process ID."""
3935

40-
def get(self, command_id: str) -> "CommandManager":
41-
return self.commands.get(command_id, None)
36+
command_obj = CommandManager(command, env, cwd)
37+
process = await command_obj.run()
38+
self.commands[process.pid] = command_obj
39+
return process.pid
4240

43-
async def cleanup(self):
44-
for _, v in self.commands.items():
41+
def get(self, pid: int) -> "CommandManager":
42+
"""Get the CommandManager object for a given process ID."""
43+
44+
return self.commands.get(pid, None)
45+
46+
async def cleanup(self) -> None:
47+
"""Clean up log files for all running subprocesses."""
48+
49+
for v in self.commands.values():
4550
await v.cleanup()
4651

4752

4853
class CommandManager(object):
49-
def __init__(self, command: List[str], env=None, cwd=None):
50-
self.command = command
54+
"""A manager for an individual subprocess."""
5155

52-
if env is None:
53-
env = os.environ.copy()
54-
self.env = env
56+
def __init__(
57+
self,
58+
command: List[str],
59+
env: Optional[Dict[str, str]] = None,
60+
cwd: Optional[str] = None,
61+
):
62+
self.command = command
5563

56-
if cwd is None:
57-
cwd = os.getcwd()
58-
self.cwd = cwd
64+
self.env = env if env is not None else os.environ.copy()
65+
self.cwd = cwd if cwd is not None else os.getcwd()
5966

6067
self.process = None
61-
self.run_called = False
62-
self.log_files = {}
68+
self.run_called: bool = False
69+
self.log_files: Dict[str, str] = {}
6370

6471
signal.signal(signal.SIGINT, self.handle_sigint)
6572

66-
async def __aenter__(self):
73+
async def __aenter__(self) -> "CommandManager":
6774
return self
6875

6976
async def __aexit__(self, exc_type, exc_value, traceback):
7077
await self.cleanup()
7178

7279
def handle_sigint(self, signum, frame):
80+
"""Handle the SIGINT signal."""
81+
7382
print("SIGINT received.")
7483
asyncio.create_task(self.kill())
7584

76-
async def wait(self, timeout=None, stream=None):
85+
async def wait(
86+
self, timeout: Optional[float] = None, stream: Optional[str] = None
87+
) -> None:
88+
"""Wait for the subprocess to finish, optionally with a timeout and optionally streaming its output."""
89+
7790
if timeout is None:
7891
if stream is None:
7992
await self.process.wait()
8093
else:
8194
await self.emit_logs(stream)
8295
else:
83-
tasks = [asyncio.create_task(asyncio.sleep(timeout))]
84-
if stream is None:
85-
tasks.append(asyncio.create_task(self.process.wait()))
86-
else:
87-
tasks.append(asyncio.create_task(self.emit_logs(stream)))
88-
89-
await asyncio.wait(tasks, return_when="FIRST_COMPLETED")
96+
try:
97+
if stream is None:
98+
await asyncio.wait_for(self.process.wait(), timeout)
99+
else:
100+
await asyncio.wait_for(self.emit_logs(stream), timeout)
101+
except asyncio.TimeoutError:
102+
command_string = " ".join(self.command)
103+
print(
104+
f"Timeout: The process: '{command_string}' didn't complete within {timeout} seconds."
105+
)
90106

91107
async def run(self):
108+
"""Run the subprocess, streaming the logs to temporary files"""
109+
92110
self.temp_dir = tempfile.mkdtemp()
93111
stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
94112
stderr_logfile = os.path.join(self.temp_dir, "stderr.log")
@@ -114,14 +132,20 @@ async def run(self):
114132
await self.cleanup()
115133

116134
async def stream_logs(
117-
self, stream, position=None, timeout_per_line=None, log_write_delay=0.01
135+
self,
136+
stream: str,
137+
position: Optional[int] = None,
138+
timeout_per_line: Optional[float] = None,
139+
log_write_delay: float = 0.01,
118140
):
141+
"""Stream logs from the subprocess using the log files"""
142+
119143
if self.run_called is False:
120144
raise ValueError("No command run yet to get the logs for...")
121145

122146
if stream not in self.log_files:
123147
raise ValueError(
124-
f"No log file found for {stream}, valid values are: {list(self.log_files.keys())}"
148+
f"No log file found for '{stream}', valid values are: {list(self.log_files.keys())}"
125149
)
126150

127151
log_file = self.log_files[stream]
@@ -161,15 +185,21 @@ async def stream_logs(
161185
position = f.tell()
162186
yield position, line.strip()
163187

164-
async def emit_logs(self, stream="stdout", custom_logger=print):
188+
async def emit_logs(self, stream: str = "stdout", custom_logger: Callable = print):
189+
"""Helper function to iterate over stream_logs"""
190+
165191
async for _, line in self.stream_logs(stream):
166192
custom_logger(line)
167193

168194
async def cleanup(self):
195+
"""Clean up log files for a running subprocesses."""
196+
169197
if hasattr(self, "temp_dir"):
170198
shutil.rmtree(self.temp_dir, ignore_errors=True)
171199

172-
async def kill(self, termination_timeout=5):
200+
async def kill(self, termination_timeout: float = 5):
201+
"""Kill the subprocess."""
202+
173203
if self.process is not None:
174204
if self.process.returncode is None:
175205
self.process.terminate()

0 commit comments

Comments
 (0)