1
1
import os
2
2
import sys
3
- import time
4
3
import signal
5
4
import shutil
6
- import hashlib
7
5
import asyncio
8
6
import tempfile
9
- from typing import List
10
-
11
-
12
- def hash_command_invocation (command : List [str ]):
13
- concatenated_string = "" .join (command )
14
- current_time = str (time .time ())
15
- concatenated_string += current_time
16
- hash_object = hashlib .sha256 (concatenated_string .encode ())
17
- return hash_object .hexdigest ()
7
+ from typing import List , Dict , Optional , Callable
18
8
19
9
20
10
class LogReadTimeoutError (Exception ):
11
+ """Exception raised when reading logs times out."""
12
+
21
13
pass
22
14
23
15
24
16
class SubprocessManager (object ):
17
+ """A manager for subprocesses."""
18
+
25
19
def __init__ (self ):
26
- self .commands = {}
20
+ self .commands : Dict [ int , CommandManager ] = {}
27
21
28
- async def __aenter__ (self ):
22
+ async def __aenter__ (self ) -> "SubprocessManager" :
29
23
return self
30
24
31
25
async def __aexit__ (self , exc_type , exc_value , traceback ):
32
26
await self .cleanup ()
33
27
34
- async def run_command (self , command : List [str ], env = None , cwd = None ):
35
- command_id = hash_command_invocation (command )
36
- self .commands [command_id ] = CommandManager (command , env , cwd )
37
- await self .commands [command_id ].run ()
38
- return command_id
28
+ async def run_command (
29
+ self ,
30
+ command : List [str ],
31
+ env : Optional [Dict [str , str ]] = None ,
32
+ cwd : Optional [str ] = None ,
33
+ ) -> int :
34
+ """Run a command asynchronously and return its process ID."""
39
35
40
- def get (self , command_id : str ) -> "CommandManager" :
41
- return self .commands .get (command_id , None )
36
+ command_obj = CommandManager (command , env , cwd )
37
+ process = await command_obj .run ()
38
+ self .commands [process .pid ] = command_obj
39
+ return process .pid
42
40
43
- async def cleanup (self ):
44
- for _ , v in self .commands .items ():
41
+ def get (self , pid : int ) -> "CommandManager" :
42
+ """Get the CommandManager object for a given process ID."""
43
+
44
+ return self .commands .get (pid , None )
45
+
46
+ async def cleanup (self ) -> None :
47
+ """Clean up log files for all running subprocesses."""
48
+
49
+ for v in self .commands .values ():
45
50
await v .cleanup ()
46
51
47
52
48
53
class CommandManager (object ):
49
- def __init__ (self , command : List [str ], env = None , cwd = None ):
50
- self .command = command
54
+ """A manager for an individual subprocess."""
51
55
52
- if env is None :
53
- env = os .environ .copy ()
54
- self .env = env
56
+ def __init__ (
57
+ self ,
58
+ command : List [str ],
59
+ env : Optional [Dict [str , str ]] = None ,
60
+ cwd : Optional [str ] = None ,
61
+ ):
62
+ self .command = command
55
63
56
- if cwd is None :
57
- cwd = os .getcwd ()
58
- self .cwd = cwd
64
+ self .env = env if env is not None else os .environ .copy ()
65
+ self .cwd = cwd if cwd is not None else os .getcwd ()
59
66
60
67
self .process = None
61
- self .run_called = False
62
- self .log_files = {}
68
+ self .run_called : bool = False
69
+ self .log_files : Dict [ str , str ] = {}
63
70
64
71
signal .signal (signal .SIGINT , self .handle_sigint )
65
72
66
- async def __aenter__ (self ):
73
+ async def __aenter__ (self ) -> "CommandManager" :
67
74
return self
68
75
69
76
async def __aexit__ (self , exc_type , exc_value , traceback ):
70
77
await self .cleanup ()
71
78
72
79
def handle_sigint (self , signum , frame ):
80
+ """Handle the SIGINT signal."""
81
+
73
82
print ("SIGINT received." )
74
83
asyncio .create_task (self .kill ())
75
84
76
- async def wait (self , timeout = None , stream = None ):
85
+ async def wait (
86
+ self , timeout : Optional [float ] = None , stream : Optional [str ] = None
87
+ ) -> None :
88
+ """Wait for the subprocess to finish, optionally with a timeout and optionally streaming its output."""
89
+
77
90
if timeout is None :
78
91
if stream is None :
79
92
await self .process .wait ()
80
93
else :
81
94
await self .emit_logs (stream )
82
95
else :
83
- tasks = [asyncio .create_task (asyncio .sleep (timeout ))]
84
- if stream is None :
85
- tasks .append (asyncio .create_task (self .process .wait ()))
86
- else :
87
- tasks .append (asyncio .create_task (self .emit_logs (stream )))
88
-
89
- await asyncio .wait (tasks , return_when = "FIRST_COMPLETED" )
96
+ try :
97
+ if stream is None :
98
+ await asyncio .wait_for (self .process .wait (), timeout )
99
+ else :
100
+ await asyncio .wait_for (self .emit_logs (stream ), timeout )
101
+ except asyncio .TimeoutError :
102
+ command_string = " " .join (self .command )
103
+ print (
104
+ f"Timeout: The process: '{ command_string } ' didn't complete within { timeout } seconds."
105
+ )
90
106
91
107
async def run (self ):
108
+ """Run the subprocess, streaming the logs to temporary files"""
109
+
92
110
self .temp_dir = tempfile .mkdtemp ()
93
111
stdout_logfile = os .path .join (self .temp_dir , "stdout.log" )
94
112
stderr_logfile = os .path .join (self .temp_dir , "stderr.log" )
@@ -114,14 +132,20 @@ async def run(self):
114
132
await self .cleanup ()
115
133
116
134
async def stream_logs (
117
- self , stream , position = None , timeout_per_line = None , log_write_delay = 0.01
135
+ self ,
136
+ stream : str ,
137
+ position : Optional [int ] = None ,
138
+ timeout_per_line : Optional [float ] = None ,
139
+ log_write_delay : float = 0.01 ,
118
140
):
141
+ """Stream logs from the subprocess using the log files"""
142
+
119
143
if self .run_called is False :
120
144
raise ValueError ("No command run yet to get the logs for..." )
121
145
122
146
if stream not in self .log_files :
123
147
raise ValueError (
124
- f"No log file found for { stream } , valid values are: { list (self .log_files .keys ())} "
148
+ f"No log file found for ' { stream } ' , valid values are: { list (self .log_files .keys ())} "
125
149
)
126
150
127
151
log_file = self .log_files [stream ]
@@ -161,15 +185,21 @@ async def stream_logs(
161
185
position = f .tell ()
162
186
yield position , line .strip ()
163
187
164
- async def emit_logs (self , stream = "stdout" , custom_logger = print ):
188
+ async def emit_logs (self , stream : str = "stdout" , custom_logger : Callable = print ):
189
+ """Helper function to iterate over stream_logs"""
190
+
165
191
async for _ , line in self .stream_logs (stream ):
166
192
custom_logger (line )
167
193
168
194
async def cleanup (self ):
195
+ """Clean up log files for a running subprocesses."""
196
+
169
197
if hasattr (self , "temp_dir" ):
170
198
shutil .rmtree (self .temp_dir , ignore_errors = True )
171
199
172
- async def kill (self , termination_timeout = 5 ):
200
+ async def kill (self , termination_timeout : float = 5 ):
201
+ """Kill the subprocess."""
202
+
173
203
if self .process is not None :
174
204
if self .process .returncode is None :
175
205
self .process .terminate ()
0 commit comments