Skip to content

Commit 4b18c7b

Browse files
committed
Use asyncio for waiting on connection.
1 parent 5cb4cee commit 4b18c7b

File tree

2 files changed

+74
-85
lines changed

2 files changed

+74
-85
lines changed

actions/ci_connection/notify_connection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,26 @@
2626

2727
def timer(conn):
2828
while True:
29-
# We lock because closed and keep_alive could technically arrive at the same time
29+
# We lock because closed and keep_alive could arrive at the same time
3030
with lock:
3131
conn.send("keep_alive")
3232
time.sleep(keep_alive_interval)
3333

3434

3535
if __name__ == "__main__":
3636
with Client(address) as conn:
37-
conn.send("connected")
37+
conn.send("connection_established")
3838

39-
# Thread is running as a daemon so it will quit when the
40-
# main thread terminates.
39+
# Thread is running as a daemon so it will quit
40+
# when the main thread terminates
4141
timer_thread = threading.Thread(target=timer, daemon=True, args=(conn,))
4242
timer_thread.start()
4343

4444
print("Entering interactive bash session")
45-
# Enter interactive bash session
45+
# Enter an interactive Bash session
4646
subprocess.run(["bash", "-i"])
4747

4848
print("Exiting interactive bash session")
4949
with lock:
50-
conn.send("closed")
50+
conn.send("connection_closed")
5151
conn.close()

actions/ci_connection/wait_for_connection.py

Lines changed: 68 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Wait for an SSH connection from a user, if a wait was requested."""
16+
17+
import asyncio
1518
import logging
1619
import os
17-
from multiprocessing.connection import Listener
18-
import time
19-
import threading
2020
import sys
21+
import time
2122

2223
from get_labels import retrieve_labels
2324

@@ -37,17 +38,7 @@
3738
format="%(levelname)s: %(message)s", stream=sys.stderr)
3839

3940

40-
class WaitInfo:
41-
wait_limit_reached = False
42-
timeout = 20
43-
# timeout = 600 # 10 minutes for initial connection
44-
last_time = time.time()
45-
# 30 minutes for keep-alive, if no closed message (allow for reconnects)
46-
# keep_alive_timeout = 900
47-
keep_alive_timeout = 20
48-
49-
50-
def _is_truthy_env_var(var_name: str) -> bool:
41+
def _is_true_like_env_var(var_name: str) -> bool:
5142
var_val = os.getenv(var_name, "").lower()
5243
negative_choices = {"0", "false", "n", "no", "none", "null", "n/a"}
5344
if var_val and var_val not in negative_choices:
@@ -60,12 +51,12 @@ def should_halt_for_connection() -> bool:
6051

6152
logging.info("Checking if the workflow should be halted for a connection...")
6253

63-
if not _is_truthy_env_var("INTERACTIVE_CI"):
54+
if not _is_true_like_env_var("INTERACTIVE_CI"):
6455
logging.info("INTERACTIVE_CI env var is not "
65-
"set, or is set to a falsy value in the workflow")
56+
"set, or is set to a false-like value in the workflow")
6657
return False
6758

68-
explicit_halt_requested = _is_truthy_env_var("HALT_DISPATCH_INPUT")
59+
explicit_halt_requested = _is_true_like_env_var("HALT_DISPATCH_INPUT")
6960
if explicit_halt_requested:
7061
logging.info("Halt for connection requested via "
7162
"explicit `halt-dispatch-input` input")
@@ -96,57 +87,39 @@ def should_halt_for_connection() -> bool:
9687
return False
9788

9889

99-
def wait_for_notification(address):
100-
"""Waits for connection notification from the listener."""
101-
while True:
102-
time.sleep(0.05)
103-
if WaitInfo.wait_limit_reached:
104-
logging.info(f"No connection in {WaitInfo.timeout} seconds - exiting ")
105-
with Listener(address) as listener:
106-
logging.info("Waiting for connection...")
107-
with listener.accept() as conn:
108-
while True:
109-
try:
110-
message = conn.recv()
111-
except EOFError as e:
112-
logging.error("EOFError occurred:", e)
113-
break
114-
logging.info("Received message")
115-
if message == "keep_alive":
116-
logging.info("Keep-alive received")
117-
WaitInfo.last_time = time.time()
118-
continue # Keep-alive received, continue waiting
119-
elif message == "closed":
120-
logging.info("Connection closed by the other process.")
121-
return # Graceful exit
122-
elif message == "connected":
123-
WaitInfo.last_time = time.time()
124-
WaitInfo.timeout = WaitInfo.keep_alive_timeout
125-
logging.info("Connected")
126-
elif message == "wait_limit_reached":
127-
logging.info("Finished waiting")
128-
else:
129-
logging.warning("Unknown message received:", message)
130-
continue
131-
132-
133-
def timer():
134-
while True:
135-
logging.info("Checking status")
136-
time_elapsed = time.time() - WaitInfo.last_time
137-
if time_elapsed < WaitInfo.timeout:
138-
logging.info(f"Time since last keep-alive: {int(time_elapsed)}s")
139-
else:
140-
WaitInfo.wait_limit_reached = True
141-
return
142-
time.sleep(60)
143-
144-
145-
def wait_for_connection():
146-
address = ("localhost", 12455) # Address and port to listen on
147-
90+
class WaitInfo:
91+
pre_connect_timeout = 10 * 60 # 10 minutes for initial connection
92+
# allow for reconnects, in case no 'closed' message is received
93+
re_connect_timeout = 15 * 60 # 15 minutes for reconnects
94+
# Dynamic, depending on whether a connection was established, or not
95+
timeout = pre_connect_timeout
96+
last_time = time.time()
97+
waiting_for_close = False
98+
stop_event = asyncio.Event()
99+
100+
101+
async def process_message(reader, writer):
102+
data = await reader.read(1024)
103+
message = data.decode().strip()
104+
if message == "keep_alive":
105+
logging.info("Keep-alive received")
106+
WaitInfo.last_time = time.time()
107+
elif message == "connection_closed":
108+
WaitInfo.waiting_for_close = True
109+
WaitInfo.stop_event.set()
110+
elif message == "connection_established":
111+
WaitInfo.last_time = time.time()
112+
WaitInfo.timeout = WaitInfo.re_connect_timeout
113+
logging.info("SSH connection detected.")
114+
else:
115+
logging.warning(f"Unknown message received: {message!r}")
116+
writer.close()
117+
118+
119+
async def wait_for_connection(host: str = 'localhost',
120+
port: int = 12455):
148121
# Print out the data required to connect to this VM
149-
host = os.getenv("HOSTNAME")
122+
runner_name = os.getenv("HOSTNAME")
150123
cluster = os.getenv("CONNECTION_CLUSTER")
151124
location = os.getenv("CONNECTION_LOCATION")
152125
ns = os.getenv("CONNECTION_NS")
@@ -156,30 +129,46 @@ def wait_for_connection():
156129
"See go/ml-github-actions:ssh for details")
157130
logging.info(
158131
f"Connection string: ml-actions-connect "
159-
f"--runner={host} "
132+
f"--runner={runner_name} "
160133
f"--ns={ns} "
161134
f"--loc={location} "
162135
f"--cluster={cluster} "
163136
f"--halt_directory={actions_path}"
164137
)
165138

166-
# Thread is running as a daemon, so it will quit when the
167-
# main thread terminates.
168-
timer_thread = threading.Thread(target=timer, daemon=True)
169-
timer_thread.start()
139+
server = await asyncio.start_server(process_message, host, port)
140+
addr = server.sockets[0].getsockname()
141+
terminate = False
142+
143+
logging.info(f"Listening for connection notifications on {addr}...")
144+
async with server:
145+
while not WaitInfo.stop_event.is_set():
146+
await asyncio.wait([asyncio.create_task(WaitInfo.stop_event.wait())],
147+
timeout=60,
148+
return_when=asyncio.FIRST_COMPLETED)
170149

171-
# Wait for connection and get the connection object
172-
wait_for_notification(address)
150+
if WaitInfo.waiting_for_close:
151+
msg = "Connection was terminated."
152+
terminate = True
153+
elif elapsed > WaitInfo.timeout:
154+
terminate = True
155+
msg = f"No connection for {WaitInfo.timeout} seconds."
173156

174-
logging.info("Exiting connection wait loop.")
175-
# Force a flush so we don't miss messages
176-
sys.stdout.flush()
157+
if terminate:
158+
logging.info(f"{msg} Shutting down the waiting process...")
159+
server.close()
160+
await server.wait_closed()
161+
break
162+
163+
elapsed = time.time() - WaitInfo.last_time
164+
logging.info(f"Time since last keep-alive: {int(elapsed)}s")
165+
166+
logging.info("Waiting process terminated.")
177167

178168

179169
if __name__ == "__main__":
180170
if not should_halt_for_connection():
181171
logging.info("No conditions for halting the workflow"
182172
"for connection were met")
183173
exit()
184-
185-
wait_for_connection()
174+
asyncio.run(wait_for_connection())

0 commit comments

Comments
 (0)