1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ """Wait for an SSH connection from a user, if a wait was requested."""
16+
17+ import asyncio
1518import logging
1619import os
17- from multiprocessing .connection import Listener
1820import time
19- import signal
20- import threading
21- import sys
2221
2322from get_labels import retrieve_labels
23+ from logging_setup import setup_logging
24+
25+ setup_logging ()
2426
25- # Check if debug logging should be enabled for the script:
26- # WAIT_FOR_CONNECTION_DEBUG is a custom variable.
27- # RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GH env vars, which can be set
28- # in various ways, one of them - enabling debug logging from the UI, when
29- # triggering a run:
30- # https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
31- # https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
32- _SHOW_DEBUG = bool (
33- os .getenv (
34- "WAIT_FOR_CONNECTION_DEBUG" ,
35- os .getenv ("RUNNER_DEBUG" , os .getenv ("ACTIONS_RUNNER_DEBUG" )),
36- )
37- )
38- logging .basicConfig (
39- level = logging .INFO if not _SHOW_DEBUG else logging .DEBUG ,
40- format = "%(levelname)s: %(message)s" ,
41- stream = sys .stderr ,
42- )
43-
44-
45- last_time = time .time ()
46- timeout = 600 # 10 minutes for initial connection
47- keep_alive_timeout = (
48- 900 # 15 minutes for keep-alive, if no closed message (allow for reconnects)
49- )
50-
51- # Labels that are used for checking whether a workflow should wait for a
52- # connection.
53- # Note: there's always a small possibility these labels may change on the
54- # repo/org level, in which case, they'd need to be updated below as well.
5527ALWAYS_HALT_LABEL = "CI Connection Halt - Always"
5628HALT_ON_RETRY_LABEL = "CI Connection Halt - On Retry"
5729
5830
59- def _is_truthy_env_var (var_name : str ) -> bool :
31+ def _is_true_like_env_var (var_name : str ) -> bool :
6032 var_val = os .getenv (var_name , "" ).lower ()
6133 negative_choices = {"0" , "false" , "n" , "no" , "none" , "null" , "n/a" }
6234 if var_val and var_val not in negative_choices :
@@ -69,22 +41,26 @@ def should_halt_for_connection() -> bool:
6941
7042 logging .info ("Checking if the workflow should be halted for a connection..." )
7143
72- if not _is_truthy_env_var ("INTERACTIVE_CI" ):
44+ if not _is_true_like_env_var ("INTERACTIVE_CI" ):
7345 logging .info (
74- "INTERACTIVE_CI env var is not " "set, or is set to a falsy value in the workflow"
46+ "INTERACTIVE_CI env var is not "
47+ "set, or is set to a false-like value in the workflow"
7548 )
7649 return False
7750
78- explicit_halt_requested = _is_truthy_env_var ("HALT_DISPATCH_INPUT" )
51+ explicit_halt_requested = _is_true_like_env_var ("HALT_DISPATCH_INPUT" )
7952 if explicit_halt_requested :
8053 logging .info (
81- "Halt for connection requested via " " explicit `halt-dispatch-input` input"
54+ "Halt for connection requested via explicit `halt-dispatch-input` input"
8255 )
8356 return True
8457
8558 # Check if any of the relevant labels are present
8659 labels = retrieve_labels (print_to_stdout = False )
8760
61+ # Note: there's always a small possibility these labels may change on the
62+ # repo/org level, in which case, they'd need to be updated below as well.
63+
8864 # TODO(belitskiy): Add the ability to halt on CI error.
8965
9066 if ALWAYS_HALT_LABEL in labels :
@@ -106,85 +82,90 @@ def should_halt_for_connection() -> bool:
10682 return False
10783
10884
109- def wait_for_notification (address ):
110- """Waits for connection notification from the listener."""
111- # TODO(belitskiy): Get rid of globals?
112- global last_time , timeout
113- while True :
114- with Listener (address ) as listener :
115- logging .info ("Waiting for connection..." )
116- with listener .accept () as conn :
117- while True :
118- try :
119- message = conn .recv ()
120- except EOFError as e :
121- logging .error ("EOFError occurred:" , e )
122- break
123- logging .info ("Received message" )
124- if message == "keep_alive" :
125- logging .info ("Keep-alive received" )
126- last_time = time .time ()
127- continue # Keep-alive received, continue waiting
128- elif message == "closed" :
129- logging .info ("Connection closed by the other process." )
130- return # Graceful exit
131- elif message == "connected" :
132- last_time = time .time ()
133- timeout = keep_alive_timeout
134- logging .info ("Connected" )
135- else :
136- logging .warning ("Unknown message received:" , message )
137- continue
138-
139-
140- def timer ():
141- while True :
142- logging .info ("Checking status" )
143- time_elapsed = time .time () - last_time
144- if time_elapsed < timeout :
145- logging .info (f"Time since last keep-alive: { int (time_elapsed )} s" )
85+ class WaitInfo :
86+ pre_connect_timeout = 10 * 60 # 10 minutes for initial connection
87+ # allow for reconnects, in case no 'closed' message is received
88+ re_connect_timeout = 15 * 60 # 15 minutes for reconnects
89+ # Dynamic, depending on whether a connection was established, or not
90+ timeout = pre_connect_timeout
91+ last_time = time .time ()
92+ waiting_for_close = False
93+ stop_event = asyncio .Event ()
94+
95+
96+ async def process_messages (reader , writer ):
97+ data = await reader .read (1024 )
98+ # Since this is a stream, multiple messages could come in at once
99+ messages = [m for m in data .decode ().strip ().splitlines () if m ]
100+ for message in messages :
101+ if message == "keep_alive" :
102+ logging .info ("Keep-alive received" )
103+ WaitInfo .last_time = time .time ()
104+ elif message == "connection_closed" :
105+ WaitInfo .waiting_for_close = True
106+ WaitInfo .stop_event .set ()
107+ elif message == "connection_established" :
108+ WaitInfo .last_time = time .time ()
109+ WaitInfo .timeout = WaitInfo .re_connect_timeout
110+ logging .info ("SSH connection detected." )
146111 else :
147- logging .info ("Timeout reached, exiting" )
148- os .kill (os .getpid (), signal .SIGTERM )
149- time .sleep (60 )
150-
112+ logging .warning (f"Unknown message received: { message !r} " )
113+ writer .close ()
151114
152- def wait_for_connection ():
153- address = ("localhost" , 12455 ) # Address and port to listen on
154115
116+ async def wait_for_connection (host : str = "localhost" , port : int = 12455 ):
155117 # Print out the data required to connect to this VM
156- host = os .getenv ("HOSTNAME" )
118+ runner_name = os .getenv ("HOSTNAME" )
157119 cluster = os .getenv ("CONNECTION_CLUSTER" )
158120 location = os .getenv ("CONNECTION_LOCATION" )
159121 ns = os .getenv ("CONNECTION_NS" )
160122 actions_path = os .getenv ("GITHUB_ACTION_PATH" )
161123
162- logging .info ("Googler connection only\n " "See go/ml-github-actions:ssh for details" )
124+ logging .info ("Googler connection only\n See go/ml-github-actions:ssh for details" )
163125 logging .info (
164126 f"Connection string: ml-actions-connect "
165- f"--runner={ host } "
127+ f"--runner={ runner_name } "
166128 f"--ns={ ns } "
167129 f"--loc={ location } "
168130 f"--cluster={ cluster } "
169131 f"--halt_directory={ actions_path } "
170132 )
171133
172- # Thread is running as a daemon, so it will quit when the
173- # main thread terminates.
174- timer_thread = threading .Thread (target = timer , daemon = True )
175- timer_thread .start ()
134+ server = await asyncio .start_server (process_messages , host , port )
135+ terminate = False
176136
177- # Wait for connection and get the connection object
178- wait_for_notification (address )
137+ logging .info (f"Listening for connection notifications on { host } :{ port } ..." )
138+ async with server :
139+ while not WaitInfo .stop_event .is_set ():
140+ # Send a status msg every 60 seconds, unless a stop message is received
141+ # from the companion script
142+ await asyncio .wait (
143+ [asyncio .create_task (WaitInfo .stop_event .wait ())],
144+ timeout = 60 ,
145+ return_when = asyncio .FIRST_COMPLETED ,
146+ )
179147
180- logging .info ("Exiting connection wait loop." )
181- # Force a flush so we don't miss messages
182- sys .stdout .flush ()
148+ elapsed_seconds = int (time .time () - WaitInfo .last_time )
149+ if WaitInfo .waiting_for_close :
150+ msg = "Connection was terminated."
151+ terminate = True
152+ elif elapsed_seconds > WaitInfo .timeout :
153+ terminate = True
154+ msg = f"No connection for { WaitInfo .timeout } seconds."
155+
156+ if terminate :
157+ logging .info (f"{ msg } Shutting down the waiting process..." )
158+ server .close ()
159+ await server .wait_closed ()
160+ break
161+
162+ logging .info (f"Time since last keep-alive: { elapsed_seconds } s" )
163+
164+ logging .info ("Waiting process terminated." )
183165
184166
185167if __name__ == "__main__" :
186168 if not should_halt_for_connection ():
187- logging .info ("No conditions for halting the workflow" " for connection were met" )
169+ logging .info ("No conditions for halting the workflow for connection were met" )
188170 exit ()
189-
190- wait_for_connection ()
171+ asyncio .run (wait_for_connection ())
0 commit comments