Skip to content

Commit 8f3f5b3

Browse files
authored
Time out waiting on SSH connection gracefully, use asyncio on receiving end. (#3)
Time out gracefully, use asyncio on receiving end.
1 parent 4f3dcbb commit 8f3f5b3

File tree

4 files changed

+151
-127
lines changed

4 files changed

+151
-127
lines changed

ci_connection/get_labels.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
4141
github_ref = os.getenv("GITHUB_REF", "")
4242
if not github_ref:
4343
raise TypeError(
44-
"GITHUB_REF is not defined. " "Is this being run outside of GitHub Actions?"
44+
"GITHUB_REF is not defined. Is this being run outside of GitHub Actions?"
4545
)
4646

4747
# Outside a PR context - no labels to be found
@@ -57,7 +57,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
5757
gh_issue = re.search(r"refs/pull/(\d+)/merge", github_ref).group(1)
5858
gh_repo = os.getenv("GITHUB_REPOSITORY")
5959
labels_url = f"https://api.github.com/repos/{gh_repo}/issues/{gh_issue}/labels"
60-
logging.debug(f"{gh_issue=!r}\n" f"{gh_repo=!r}")
60+
logging.debug(f"{gh_issue=!r}\n{gh_repo=!r}")
6161

6262
wait_time = 3
6363
total_attempts = 3
@@ -78,7 +78,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
7878

7979
if response.status == 200:
8080
data = response.read().decode("utf-8")
81-
logging.debug("API labels data: \n" f"{data}")
81+
logging.debug("API labels data: \n{data}")
8282
break
8383
else:
8484
logging.error(f"Request failed with status code: {response.status}")
@@ -96,10 +96,10 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
9696
with open(event_payload_path, "r", encoding="utf-8") as event_payload:
9797
data_json = json.load(event_payload).get("pull_request", {}).get("labels", [])
9898
logging.info("Using fallback labels")
99-
logging.info(f"Fallback labels: \n" f"{data_json}")
99+
logging.info(f"Fallback labels: \n{data_json}")
100100

101101
labels = [label["name"] for label in data_json]
102-
logging.debug(f"Final labels: \n" f"{labels}")
102+
logging.debug(f"Final labels: \n{labels}")
103103

104104
# Output the labels to stdout for further use elsewhere
105105
if print_to_stdout:

ci_connection/logging_setup.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import logging
2+
import os
3+
import sys
4+
5+
# Check if debug logging should be enabled for the scripts:
6+
# WAIT_FOR_CONNECTION_DEBUG is a custom variable.
7+
# RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GH env vars, which can be set
8+
# in various ways, one of them - enabling debug logging from the UI, when
9+
# triggering a run:
10+
# https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
11+
# https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
12+
_SHOW_DEBUG = bool(
13+
os.getenv(
14+
"WAIT_FOR_CONNECTION_DEBUG",
15+
os.getenv("RUNNER_DEBUG", os.getenv("ACTIONS_RUNNER_DEBUG")),
16+
)
17+
)
18+
19+
20+
def setup_logging():
21+
logging.basicConfig(
22+
level=logging.INFO if not _SHOW_DEBUG else logging.DEBUG,
23+
format="%(levelname)s: %(message)s",
24+
stream=sys.stderr,
25+
)

ci_connection/notify_connection.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,58 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
16+
import socket
1517
import time
1618
import threading
1719
import subprocess
18-
from multiprocessing.connection import Client
1920

20-
lock = threading.Lock()
21+
from logging_setup import setup_logging
22+
23+
setup_logging()
24+
25+
_LOCK = threading.Lock()
2126

2227
# Configuration (same as wait_for_connection.py)
23-
address = ("localhost", 12455)
24-
keep_alive_interval = 30 # 30 seconds
28+
HOST, PORT = "localhost", 12455
29+
KEEP_ALIVE_INTERVAL = 30
30+
2531

32+
def send_message(message: str):
33+
with _LOCK:
34+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
35+
# Append a newline to split the messages on the backend,
36+
# in case multiple ones are received together
37+
try:
38+
sock.connect((HOST, PORT))
39+
sock.sendall(f"{message}\n".encode("utf-8"))
40+
except ConnectionRefusedError:
41+
logging.error(
42+
f"Could not connect to server at {HOST}:{PORT}. Is the server running?"
43+
)
44+
except Exception as e:
45+
logging.error(f"An error occurred: {e}")
2646

27-
def timer(conn):
47+
48+
def keep_alive():
2849
while True:
29-
# We lock because closed and keep_alive could technically arrive at the same time
30-
with lock:
31-
conn.send("keep_alive")
32-
time.sleep(keep_alive_interval)
50+
time.sleep(KEEP_ALIVE_INTERVAL)
51+
send_message("keep_alive")
52+
53+
54+
def main():
55+
send_message("connection_established")
56+
57+
# Thread is running as a daemon so it will quit
58+
# when the main thread terminates
59+
timer_thread = threading.Thread(target=keep_alive, daemon=True)
60+
timer_thread.start()
61+
62+
# Enter an interactive Bash session
63+
subprocess.run(["bash", "-i"])
64+
65+
send_message("connection_closed")
3366

3467

3568
if __name__ == "__main__":
36-
with Client(address) as conn:
37-
conn.send("connected")
38-
39-
# Thread is running as a daemon so it will quit when the
40-
# main thread terminates.
41-
timer_thread = threading.Thread(target=timer, daemon=True, args=(conn,))
42-
timer_thread.start()
43-
44-
print("Entering interactive bash session")
45-
# Enter interactive bash session
46-
subprocess.run(["/bin/bash", "-i"])
47-
48-
print("Exiting interactive bash session")
49-
with lock:
50-
conn.send("closed")
51-
conn.close()
69+
main()

ci_connection/wait_for_connection.py

Lines changed: 78 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -12,51 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Wait for an SSH connection from a user, if a wait was requested."""
16+
17+
import asyncio
1518
import logging
1619
import os
17-
from multiprocessing.connection import Listener
1820
import time
19-
import signal
20-
import threading
21-
import sys
2221

2322
from get_labels import retrieve_labels
23+
from logging_setup import setup_logging
24+
25+
setup_logging()
2426

25-
# Check if debug logging should be enabled for the script:
26-
# WAIT_FOR_CONNECTION_DEBUG is a custom variable.
27-
# RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GH env vars, which can be set
28-
# in various ways, one of them - enabling debug logging from the UI, when
29-
# triggering a run:
30-
# https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
31-
# https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
32-
_SHOW_DEBUG = bool(
33-
os.getenv(
34-
"WAIT_FOR_CONNECTION_DEBUG",
35-
os.getenv("RUNNER_DEBUG", os.getenv("ACTIONS_RUNNER_DEBUG")),
36-
)
37-
)
38-
logging.basicConfig(
39-
level=logging.INFO if not _SHOW_DEBUG else logging.DEBUG,
40-
format="%(levelname)s: %(message)s",
41-
stream=sys.stderr,
42-
)
43-
44-
45-
last_time = time.time()
46-
timeout = 600 # 10 minutes for initial connection
47-
keep_alive_timeout = (
48-
900 # 15 minutes for keep-alive, if no closed message (allow for reconnects)
49-
)
50-
51-
# Labels that are used for checking whether a workflow should wait for a
52-
# connection.
53-
# Note: there's always a small possibility these labels may change on the
54-
# repo/org level, in which case, they'd need to be updated below as well.
5527
ALWAYS_HALT_LABEL = "CI Connection Halt - Always"
5628
HALT_ON_RETRY_LABEL = "CI Connection Halt - On Retry"
5729

5830

59-
def _is_truthy_env_var(var_name: str) -> bool:
31+
def _is_true_like_env_var(var_name: str) -> bool:
6032
var_val = os.getenv(var_name, "").lower()
6133
negative_choices = {"0", "false", "n", "no", "none", "null", "n/a"}
6234
if var_val and var_val not in negative_choices:
@@ -69,22 +41,26 @@ def should_halt_for_connection() -> bool:
6941

7042
logging.info("Checking if the workflow should be halted for a connection...")
7143

72-
if not _is_truthy_env_var("INTERACTIVE_CI"):
44+
if not _is_true_like_env_var("INTERACTIVE_CI"):
7345
logging.info(
74-
"INTERACTIVE_CI env var is not " "set, or is set to a falsy value in the workflow"
46+
"INTERACTIVE_CI env var is not "
47+
"set, or is set to a false-like value in the workflow"
7548
)
7649
return False
7750

78-
explicit_halt_requested = _is_truthy_env_var("HALT_DISPATCH_INPUT")
51+
explicit_halt_requested = _is_true_like_env_var("HALT_DISPATCH_INPUT")
7952
if explicit_halt_requested:
8053
logging.info(
81-
"Halt for connection requested via " "explicit `halt-dispatch-input` input"
54+
"Halt for connection requested via explicit `halt-dispatch-input` input"
8255
)
8356
return True
8457

8558
# Check if any of the relevant labels are present
8659
labels = retrieve_labels(print_to_stdout=False)
8760

61+
# Note: there's always a small possibility these labels may change on the
62+
# repo/org level, in which case, they'd need to be updated below as well.
63+
8864
# TODO(belitskiy): Add the ability to halt on CI error.
8965

9066
if ALWAYS_HALT_LABEL in labels:
@@ -106,85 +82,90 @@ def should_halt_for_connection() -> bool:
10682
return False
10783

10884

109-
def wait_for_notification(address):
110-
"""Waits for connection notification from the listener."""
111-
# TODO(belitskiy): Get rid of globals?
112-
global last_time, timeout
113-
while True:
114-
with Listener(address) as listener:
115-
logging.info("Waiting for connection...")
116-
with listener.accept() as conn:
117-
while True:
118-
try:
119-
message = conn.recv()
120-
except EOFError as e:
121-
logging.error("EOFError occurred:", e)
122-
break
123-
logging.info("Received message")
124-
if message == "keep_alive":
125-
logging.info("Keep-alive received")
126-
last_time = time.time()
127-
continue # Keep-alive received, continue waiting
128-
elif message == "closed":
129-
logging.info("Connection closed by the other process.")
130-
return # Graceful exit
131-
elif message == "connected":
132-
last_time = time.time()
133-
timeout = keep_alive_timeout
134-
logging.info("Connected")
135-
else:
136-
logging.warning("Unknown message received:", message)
137-
continue
138-
139-
140-
def timer():
141-
while True:
142-
logging.info("Checking status")
143-
time_elapsed = time.time() - last_time
144-
if time_elapsed < timeout:
145-
logging.info(f"Time since last keep-alive: {int(time_elapsed)}s")
85+
class WaitInfo:
86+
pre_connect_timeout = 10 * 60 # 10 minutes for initial connection
87+
# allow for reconnects, in case no 'closed' message is received
88+
re_connect_timeout = 15 * 60 # 15 minutes for reconnects
89+
# Dynamic, depending on whether a connection was established, or not
90+
timeout = pre_connect_timeout
91+
last_time = time.time()
92+
waiting_for_close = False
93+
stop_event = asyncio.Event()
94+
95+
96+
async def process_messages(reader, writer):
97+
data = await reader.read(1024)
98+
# Since this is a stream, multiple messages could come in at once
99+
messages = [m for m in data.decode().strip().splitlines() if m]
100+
for message in messages:
101+
if message == "keep_alive":
102+
logging.info("Keep-alive received")
103+
WaitInfo.last_time = time.time()
104+
elif message == "connection_closed":
105+
WaitInfo.waiting_for_close = True
106+
WaitInfo.stop_event.set()
107+
elif message == "connection_established":
108+
WaitInfo.last_time = time.time()
109+
WaitInfo.timeout = WaitInfo.re_connect_timeout
110+
logging.info("SSH connection detected.")
146111
else:
147-
logging.info("Timeout reached, exiting")
148-
os.kill(os.getpid(), signal.SIGTERM)
149-
time.sleep(60)
150-
112+
logging.warning(f"Unknown message received: {message!r}")
113+
writer.close()
151114

152-
def wait_for_connection():
153-
address = ("localhost", 12455) # Address and port to listen on
154115

116+
async def wait_for_connection(host: str = "localhost", port: int = 12455):
155117
# Print out the data required to connect to this VM
156-
host = os.getenv("HOSTNAME")
118+
runner_name = os.getenv("HOSTNAME")
157119
cluster = os.getenv("CONNECTION_CLUSTER")
158120
location = os.getenv("CONNECTION_LOCATION")
159121
ns = os.getenv("CONNECTION_NS")
160122
actions_path = os.getenv("GITHUB_ACTION_PATH")
161123

162-
logging.info("Googler connection only\n" "See go/ml-github-actions:ssh for details")
124+
logging.info("Googler connection only\nSee go/ml-github-actions:ssh for details")
163125
logging.info(
164126
f"Connection string: ml-actions-connect "
165-
f"--runner={host} "
127+
f"--runner={runner_name} "
166128
f"--ns={ns} "
167129
f"--loc={location} "
168130
f"--cluster={cluster} "
169131
f"--halt_directory={actions_path}"
170132
)
171133

172-
# Thread is running as a daemon, so it will quit when the
173-
# main thread terminates.
174-
timer_thread = threading.Thread(target=timer, daemon=True)
175-
timer_thread.start()
134+
server = await asyncio.start_server(process_messages, host, port)
135+
terminate = False
176136

177-
# Wait for connection and get the connection object
178-
wait_for_notification(address)
137+
logging.info(f"Listening for connection notifications on {host}:{port}...")
138+
async with server:
139+
while not WaitInfo.stop_event.is_set():
140+
# Send a status msg every 60 seconds, unless a stop message is received
141+
# from the companion script
142+
await asyncio.wait(
143+
[asyncio.create_task(WaitInfo.stop_event.wait())],
144+
timeout=60,
145+
return_when=asyncio.FIRST_COMPLETED,
146+
)
179147

180-
logging.info("Exiting connection wait loop.")
181-
# Force a flush so we don't miss messages
182-
sys.stdout.flush()
148+
elapsed_seconds = int(time.time() - WaitInfo.last_time)
149+
if WaitInfo.waiting_for_close:
150+
msg = "Connection was terminated."
151+
terminate = True
152+
elif elapsed_seconds > WaitInfo.timeout:
153+
terminate = True
154+
msg = f"No connection for {WaitInfo.timeout} seconds."
155+
156+
if terminate:
157+
logging.info(f"{msg} Shutting down the waiting process...")
158+
server.close()
159+
await server.wait_closed()
160+
break
161+
162+
logging.info(f"Time since last keep-alive: {elapsed_seconds}s")
163+
164+
logging.info("Waiting process terminated.")
183165

184166

185167
if __name__ == "__main__":
186168
if not should_halt_for_connection():
187-
logging.info("No conditions for halting the workflow" "for connection were met")
169+
logging.info("No conditions for halting the workflow for connection were met")
188170
exit()
189-
190-
wait_for_connection()
171+
asyncio.run(wait_for_connection())

0 commit comments

Comments
 (0)