1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ """Wait for an SSH connection from a user, if a wait was requested."""
16+
17+ import asyncio
1518import logging
1619import os
17- from multiprocessing .connection import Listener
18- import time
19- import threading
2020import sys
21+ import time
2122
2223from get_labels import retrieve_labels
2324
3738 format = "%(levelname)s: %(message)s" , stream = sys .stderr )
3839
3940
40- class WaitInfo :
41- wait_limit_reached = False
42- timeout = 20
43- # timeout = 600 # 10 minutes for initial connection
44- last_time = time .time ()
45- # 30 minutes for keep-alive, if no closed message (allow for reconnects)
46- # keep_alive_timeout = 900
47- keep_alive_timeout = 20
48-
49-
50- def _is_truthy_env_var (var_name : str ) -> bool :
41+ def _is_true_like_env_var (var_name : str ) -> bool :
5142 var_val = os .getenv (var_name , "" ).lower ()
5243 negative_choices = {"0" , "false" , "n" , "no" , "none" , "null" , "n/a" }
5344 if var_val and var_val not in negative_choices :
@@ -60,12 +51,12 @@ def should_halt_for_connection() -> bool:
6051
6152 logging .info ("Checking if the workflow should be halted for a connection..." )
6253
63- if not _is_truthy_env_var ("INTERACTIVE_CI" ):
54+ if not _is_true_like_env_var ("INTERACTIVE_CI" ):
6455 logging .info ("INTERACTIVE_CI env var is not "
65- "set, or is set to a falsy value in the workflow" )
56+ "set, or is set to a false-like value in the workflow" )
6657 return False
6758
68- explicit_halt_requested = _is_truthy_env_var ("HALT_DISPATCH_INPUT" )
59+ explicit_halt_requested = _is_true_like_env_var ("HALT_DISPATCH_INPUT" )
6960 if explicit_halt_requested :
7061 logging .info ("Halt for connection requested via "
7162 "explicit `halt-dispatch-input` input" )
@@ -96,57 +87,39 @@ def should_halt_for_connection() -> bool:
9687 return False
9788
9889
99- def wait_for_notification (address ):
100- """Waits for connection notification from the listener."""
101- while True :
102- time .sleep (0.05 )
103- if WaitInfo .wait_limit_reached :
104- logging .info (f"No connection in { WaitInfo .timeout } seconds - exiting " )
105- with Listener (address ) as listener :
106- logging .info ("Waiting for connection..." )
107- with listener .accept () as conn :
108- while True :
109- try :
110- message = conn .recv ()
111- except EOFError as e :
112- logging .error ("EOFError occurred:" , e )
113- break
114- logging .info ("Received message" )
115- if message == "keep_alive" :
116- logging .info ("Keep-alive received" )
117- WaitInfo .last_time = time .time ()
118- continue # Keep-alive received, continue waiting
119- elif message == "closed" :
120- logging .info ("Connection closed by the other process." )
121- return # Graceful exit
122- elif message == "connected" :
123- WaitInfo .last_time = time .time ()
124- WaitInfo .timeout = WaitInfo .keep_alive_timeout
125- logging .info ("Connected" )
126- elif message == "wait_limit_reached" :
127- logging .info ("Finished waiting" )
128- else :
129- logging .warning ("Unknown message received:" , message )
130- continue
131-
132-
133- def timer ():
134- while True :
135- logging .info ("Checking status" )
136- time_elapsed = time .time () - WaitInfo .last_time
137- if time_elapsed < WaitInfo .timeout :
138- logging .info (f"Time since last keep-alive: { int (time_elapsed )} s" )
139- else :
140- WaitInfo .wait_limit_reached = True
141- return
142- time .sleep (60 )
143-
144-
145- def wait_for_connection ():
146- address = ("localhost" , 12455 ) # Address and port to listen on
147-
class WaitInfo:
    """Module-wide mutable state shared by the connection wait logic.

    The attributes are read and written directly on the class; no instances
    are created.
    """

    # Seconds to wait for the initial connection (10 minutes).
    pre_connect_timeout: int = 600
    # Seconds to wait for a reconnect (15 minutes); covers the case where
    # no 'closed' message is ever received.
    re_connect_timeout: int = 900
    # Timeout currently in force. Starts as the pre-connect value and is
    # dynamic, depending on whether a connection was established, or not.
    timeout: int = pre_connect_timeout
    # Reference timestamp for the timeout check; initialised when the class
    # body is executed.
    last_time: float = time.time()
    # Flag used to signal that the wait should end because of a close.
    waiting_for_close: bool = False
    # Event used to wake up the wait loop before its polling interval ends.
    stop_event: asyncio.Event = asyncio.Event()
99+
100+
async def process_message(reader, writer):
    """Handle a single notification sent to the wait server.

    Reads up to 1024 bytes from the client, strips the decoded text and
    treats it as one command that updates the shared `WaitInfo` state:

    * ``keep_alive`` - refreshes `WaitInfo.last_time`.
    * ``connection_established`` - refreshes `WaitInfo.last_time` and
      switches `WaitInfo.timeout` to `WaitInfo.re_connect_timeout`.
    * ``connection_closed`` - sets `WaitInfo.waiting_for_close` and
      `WaitInfo.stop_event`.

    Anything else is logged as an unknown message.

    Args:
      reader: asyncio stream reader for the client connection.
      writer: asyncio stream writer for the client connection; it is always
        closed before this coroutine returns.
    """
    try:
        data = await reader.read(1024)
        # Undecodable bytes must not crash the handler: replace them so the
        # payload ends up reported as an unknown message instead.
        message = data.decode(errors="replace").strip()
        if message == "keep_alive":
            logging.info("Keep-alive received")
            WaitInfo.last_time = time.time()
        elif message == "connection_closed":
            WaitInfo.waiting_for_close = True
            WaitInfo.stop_event.set()
        elif message == "connection_established":
            WaitInfo.last_time = time.time()
            WaitInfo.timeout = WaitInfo.re_connect_timeout
            logging.info("SSH connection detected.")
        else:
            logging.warning(f"Unknown message received: {message!r}")
    finally:
        # Release the connection even if reading or handling failed.
        writer.close()
        try:
            await writer.wait_closed()
        except OSError as e:
            # The peer may already have dropped the connection; there is
            # nothing left to clean up in that case.
            logging.debug(f"Error while closing client connection: {e!r}")
117+
118+
119+ async def wait_for_connection (host : str = 'localhost' ,
120+ port : int = 12455 ):
148121 # Print out the data required to connect to this VM
149- host = os .getenv ("HOSTNAME" )
122+ runner_name = os .getenv ("HOSTNAME" )
150123 cluster = os .getenv ("CONNECTION_CLUSTER" )
151124 location = os .getenv ("CONNECTION_LOCATION" )
152125 ns = os .getenv ("CONNECTION_NS" )
@@ -156,30 +129,46 @@ def wait_for_connection():
156129 "See go/ml-github-actions:ssh for details" )
157130 logging .info (
158131 f"Connection string: ml-actions-connect "
159- f"--runner={ host } "
132+ f"--runner={ runner_name } "
160133 f"--ns={ ns } "
161134 f"--loc={ location } "
162135 f"--cluster={ cluster } "
163136 f"--halt_directory={ actions_path } "
164137 )
165138
166- # Thread is running as a daemon, so it will quit when the
167- # main thread terminates.
168- timer_thread = threading .Thread (target = timer , daemon = True )
169- timer_thread .start ()
139+ server = await asyncio .start_server (process_message , host , port )
140+ addr = server .sockets [0 ].getsockname ()
141+ terminate = False
142+
143+ logging .info (f"Listening for connection notifications on { addr } ..." )
144+ async with server :
145+ while not WaitInfo .stop_event .is_set ():
146+ await asyncio .wait ([asyncio .create_task (WaitInfo .stop_event .wait ())],
147+ timeout = 60 ,
148+ return_when = asyncio .FIRST_COMPLETED )
170149
171- # Wait for connection and get the connection object
172- wait_for_notification (address )
150+ if WaitInfo .waiting_for_close :
151+ msg = "Connection was terminated."
152+ terminate = True
153+ elif elapsed > WaitInfo .timeout :
154+ terminate = True
155+ msg = f"No connection for { WaitInfo .timeout } seconds."
173156
174- logging .info ("Exiting connection wait loop." )
175- # Force a flush so we don't miss messages
176- sys .stdout .flush ()
157+ if terminate :
158+ logging .info (f"{ msg } Shutting down the waiting process..." )
159+ server .close ()
160+ await server .wait_closed ()
161+ break
162+
163+ elapsed = time .time () - WaitInfo .last_time
164+ logging .info (f"Time since last keep-alive: { int (elapsed )} s" )
165+
166+ logging .info ("Waiting process terminated." )
177167
178168
if __name__ == "__main__":
    # Exit early when nothing requested that the workflow be halted.
    if not should_halt_for_connection():
        # Note the trailing space: the two literals are concatenated.
        logging.info("No conditions for halting the workflow "
                     "for connection were met")
        # `sys.exit` is always available, unlike the `site`-provided `exit`.
        sys.exit()
    asyncio.run(wait_for_connection())
0 commit comments