10 changes: 5 additions & 5 deletions ci_connection/get_labels.py
@@ -41,7 +41,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
github_ref = os.getenv("GITHUB_REF", "")
if not github_ref:
raise TypeError(
"GITHUB_REF is not defined. " "Is this being run outside of GitHub Actions?"
"GITHUB_REF is not defined. Is this being run outside of GitHub Actions?"
)

# Outside a PR context - no labels to be found
@@ -57,7 +57,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
gh_issue = re.search(r"refs/pull/(\d+)/merge", github_ref).group(1)
gh_repo = os.getenv("GITHUB_REPOSITORY")
labels_url = f"https://api.github.com/repos/{gh_repo}/issues/{gh_issue}/labels"
logging.debug(f"{gh_issue=!r}\n" f"{gh_repo=!r}")
logging.debug(f"{gh_issue=!r}\n{gh_repo=!r}")

wait_time = 3
total_attempts = 3
@@ -78,7 +78,7 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:

if response.status == 200:
data = response.read().decode("utf-8")
logging.debug("API labels data: \n" f"{data}")
logging.debug("API labels data: \n{data}")
break
else:
logging.error(f"Request failed with status code: {response.status}")
@@ -96,10 +96,10 @@ def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
with open(event_payload_path, "r", encoding="utf-8") as event_payload:
data_json = json.load(event_payload).get("pull_request", {}).get("labels", [])
logging.info("Using fallback labels")
logging.info(f"Fallback labels: \n" f"{data_json}")
logging.info(f"Fallback labels: \n{data_json}")

labels = [label["name"] for label in data_json]
logging.debug(f"Final labels: \n" f"{labels}")
logging.debug(f"Final labels: \n{labels}")

# Output the labels to stdout for further use elsewhere
if print_to_stdout:
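As context for the GITHUB_REF handling in this file, a minimal sketch of the PR-number extraction and the labels URL it feeds; the ref value and repository slug below are made up for illustration:

import re

github_ref = "refs/pull/1234/merge"  # illustrative value set by GitHub Actions for PR runs
gh_repo = "example-org/example-repo"  # illustrative GITHUB_REPOSITORY value

match = re.search(r"refs/pull/(\d+)/merge", github_ref)
if match:
    gh_issue = match.group(1)  # -> "1234"
    labels_url = f"https://api.github.com/repos/{gh_repo}/issues/{gh_issue}/labels"
    print(labels_url)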
25 changes: 25 additions & 0 deletions ci_connection/logging_setup.py
@@ -0,0 +1,25 @@
import logging
import os
import sys

# Check if debug logging should be enabled for the scripts:
# WAIT_FOR_CONNECTION_DEBUG is a custom variable.
# RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GitHub env vars, which can be set
# in various ways; one of them is enabling debug logging from the UI when
# triggering a run:
# https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
# https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
_SHOW_DEBUG = bool(
os.getenv(
"WAIT_FOR_CONNECTION_DEBUG",
os.getenv("RUNNER_DEBUG", os.getenv("ACTIONS_RUNNER_DEBUG")),
)
)


def setup_logging():
logging.basicConfig(
level=logging.INFO if not _SHOW_DEBUG else logging.DEBUG,
format="%(levelname)s: %(message)s",
stream=sys.stderr,
)
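A minimal usage sketch for the new helper, assuming the debug variable is exported before logging_setup is imported (that is when _SHOW_DEBUG is evaluated); the value "1" is illustrative, and note that any non-empty value of the three env vars enables debug output:

import logging
import os

# Simulate the workflow exporting the custom debug variable;
# "1" is an illustrative value - any non-empty string works here.
os.environ["WAIT_FOR_CONNECTION_DEBUG"] = "1"

from logging_setup import setup_logging

setup_logging()
logging.debug("shown only when one of the debug env vars is set")
logging.info("always shown")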
68 changes: 43 additions & 25 deletions ci_connection/notify_connection.py
@@ -12,40 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import socket
import time
import threading
import subprocess
from multiprocessing.connection import Client

lock = threading.Lock()
from logging_setup import setup_logging

setup_logging()

_LOCK = threading.Lock()

# Configuration (same as wait_for_connection.py)
address = ("localhost", 12455)
keep_alive_interval = 30 # 30 seconds
HOST, PORT = "localhost", 12455
KEEP_ALIVE_INTERVAL = 30


def send_message(message: str):
with _LOCK:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
# Append a newline to split the messages on the backend,
# in case multiple ones are received together
try:
sock.connect((HOST, PORT))
sock.sendall(f"{message}\n".encode("utf-8"))
except ConnectionRefusedError:
logging.error(
f"Could not connect to server at {HOST}:{PORT}. Is the server running?"
)
except Exception as e:
logging.error(f"An error occurred: {e}")

def timer(conn):

def keep_alive():
while True:
# We lock because closed and keep_alive could technically arrive at the same time
with lock:
conn.send("keep_alive")
time.sleep(keep_alive_interval)
time.sleep(KEEP_ALIVE_INTERVAL)
send_message("keep_alive")


def main():
send_message("connection_established")

# Thread is running as a daemon so it will quit
# when the main thread terminates
timer_thread = threading.Thread(target=keep_alive, daemon=True)
timer_thread.start()

# Enter an interactive Bash session
subprocess.run(["bash", "-i"])

send_message("connection_closed")


if __name__ == "__main__":
with Client(address) as conn:
conn.send("connected")

# Thread is running as a daemon so it will quit when the
# main thread terminates.
timer_thread = threading.Thread(target=timer, daemon=True, args=(conn,))
timer_thread.start()

print("Entering interactive bash session")
# Enter interactive bash session
subprocess.run(["/bin/bash", "-i"])

print("Exiting interactive bash session")
with lock:
conn.send("closed")
conn.close()
main()
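To make the newline framing used by send_message concrete, a small sketch of how a receiver recovers individual messages when several arrive in one stream read, mirroring the splitlines handling in wait_for_connection.py; the byte string is a made-up example of two coalesced messages:

# Two messages coalesced into a single TCP read, e.g. a keep-alive
# followed immediately by the close notification.
data = b"keep_alive\nconnection_closed\n"

messages = [m for m in data.decode().strip().splitlines() if m]
assert messages == ["keep_alive", "connection_closed"]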
175 changes: 78 additions & 97 deletions ci_connection/wait_for_connection.py
@@ -12,51 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wait for an SSH connection from a user, if a wait was requested."""

import asyncio
import logging
import os
from multiprocessing.connection import Listener
import time
import signal
import threading
import sys

from get_labels import retrieve_labels
from logging_setup import setup_logging

setup_logging()

# Check if debug logging should be enabled for the script:
# WAIT_FOR_CONNECTION_DEBUG is a custom variable.
# RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GH env vars, which can be set
# in various ways, one of them - enabling debug logging from the UI, when
# triggering a run:
# https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
# https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
_SHOW_DEBUG = bool(
os.getenv(
"WAIT_FOR_CONNECTION_DEBUG",
os.getenv("RUNNER_DEBUG", os.getenv("ACTIONS_RUNNER_DEBUG")),
)
)
logging.basicConfig(
level=logging.INFO if not _SHOW_DEBUG else logging.DEBUG,
format="%(levelname)s: %(message)s",
stream=sys.stderr,
)


last_time = time.time()
timeout = 600 # 10 minutes for initial connection
keep_alive_timeout = (
900 # 15 minutes for keep-alive, if no closed message (allow for reconnects)
)

# Labels that are used for checking whether a workflow should wait for a
# connection.
# Note: there's always a small possibility these labels may change on the
# repo/org level, in which case, they'd need to be updated below as well.
ALWAYS_HALT_LABEL = "CI Connection Halt - Always"
HALT_ON_RETRY_LABEL = "CI Connection Halt - On Retry"


def _is_truthy_env_var(var_name: str) -> bool:
def _is_true_like_env_var(var_name: str) -> bool:
var_val = os.getenv(var_name, "").lower()
negative_choices = {"0", "false", "n", "no", "none", "null", "n/a"}
if var_val and var_val not in negative_choices:
@@ -69,22 +41,26 @@ def should_halt_for_connection() -> bool:

logging.info("Checking if the workflow should be halted for a connection...")

if not _is_truthy_env_var("INTERACTIVE_CI"):
if not _is_true_like_env_var("INTERACTIVE_CI"):
logging.info(
"INTERACTIVE_CI env var is not " "set, or is set to a falsy value in the workflow"
"INTERACTIVE_CI env var is not "
"set, or is set to a false-like value in the workflow"
)
return False

explicit_halt_requested = _is_truthy_env_var("HALT_DISPATCH_INPUT")
explicit_halt_requested = _is_true_like_env_var("HALT_DISPATCH_INPUT")
if explicit_halt_requested:
logging.info(
"Halt for connection requested via " "explicit `halt-dispatch-input` input"
"Halt for connection requested via explicit `halt-dispatch-input` input"
)
return True

# Check if any of the relevant labels are present
labels = retrieve_labels(print_to_stdout=False)

# Note: there's always a small possibility these labels may change on the
# repo/org level, in which case, they'd need to be updated below as well.

# TODO(belitskiy): Add the ability to halt on CI error.

if ALWAYS_HALT_LABEL in labels:
@@ -106,85 +82,90 @@ def should_halt_for_connection() -> bool:
return False


def wait_for_notification(address):
"""Waits for connection notification from the listener."""
# TODO(belitskiy): Get rid of globals?
global last_time, timeout
while True:
with Listener(address) as listener:
logging.info("Waiting for connection...")
with listener.accept() as conn:
while True:
try:
message = conn.recv()
except EOFError as e:
logging.error("EOFError occurred:", e)
break
logging.info("Received message")
if message == "keep_alive":
logging.info("Keep-alive received")
last_time = time.time()
continue # Keep-alive received, continue waiting
elif message == "closed":
logging.info("Connection closed by the other process.")
return # Graceful exit
elif message == "connected":
last_time = time.time()
timeout = keep_alive_timeout
logging.info("Connected")
else:
logging.warning("Unknown message received:", message)
continue


def timer():
while True:
logging.info("Checking status")
time_elapsed = time.time() - last_time
if time_elapsed < timeout:
logging.info(f"Time since last keep-alive: {int(time_elapsed)}s")
class WaitInfo:
pre_connect_timeout = 10 * 60 # 10 minutes for initial connection
# allow for reconnects, in case no 'connection_closed' message is received
re_connect_timeout = 15 * 60 # 15 minutes for reconnects
# Dynamic, depending on whether a connection has been established
timeout = pre_connect_timeout
last_time = time.time()
waiting_for_close = False
stop_event = asyncio.Event()


async def process_messages(reader, writer):
data = await reader.read(1024)
# Since this is a stream, multiple messages could come in at once
messages = [m for m in data.decode().strip().splitlines() if m]
for message in messages:
if message == "keep_alive":
logging.info("Keep-alive received")
WaitInfo.last_time = time.time()
elif message == "connection_closed":
WaitInfo.waiting_for_close = True
WaitInfo.stop_event.set()
elif message == "connection_established":
WaitInfo.last_time = time.time()
WaitInfo.timeout = WaitInfo.re_connect_timeout
logging.info("SSH connection detected.")
else:
logging.info("Timeout reached, exiting")
os.kill(os.getpid(), signal.SIGTERM)
time.sleep(60)

logging.warning(f"Unknown message received: {message!r}")
writer.close()

def wait_for_connection():
address = ("localhost", 12455) # Address and port to listen on

async def wait_for_connection(host: str = "localhost", port: int = 12455):
# Print out the data required to connect to this VM
host = os.getenv("HOSTNAME")
runner_name = os.getenv("HOSTNAME")
cluster = os.getenv("CONNECTION_CLUSTER")
location = os.getenv("CONNECTION_LOCATION")
ns = os.getenv("CONNECTION_NS")
actions_path = os.getenv("GITHUB_ACTION_PATH")

logging.info("Googler connection only\n" "See go/ml-github-actions:ssh for details")
logging.info("Googler connection only\nSee go/ml-github-actions:ssh for details")
logging.info(
f"Connection string: ml-actions-connect "
f"--runner={host} "
f"--runner={runner_name} "
f"--ns={ns} "
f"--loc={location} "
f"--cluster={cluster} "
f"--halt_directory={actions_path}"
)

# Thread is running as a daemon, so it will quit when the
# main thread terminates.
timer_thread = threading.Thread(target=timer, daemon=True)
timer_thread.start()
server = await asyncio.start_server(process_messages, host, port)
terminate = False

# Wait for connection and get the connection object
wait_for_notification(address)
logging.info(f"Listening for connection notifications on {host}:{port}...")
async with server:
while not WaitInfo.stop_event.is_set():
# Send a status msg every 60 seconds, unless a stop message is received
# from the companion script
await asyncio.wait(
[asyncio.create_task(WaitInfo.stop_event.wait())],
timeout=60,
return_when=asyncio.FIRST_COMPLETED,
)

logging.info("Exiting connection wait loop.")
# Force a flush so we don't miss messages
sys.stdout.flush()
elapsed_seconds = int(time.time() - WaitInfo.last_time)
if WaitInfo.waiting_for_close:
msg = "Connection was terminated."
terminate = True
elif elapsed_seconds > WaitInfo.timeout:
terminate = True
msg = f"No connection for {WaitInfo.timeout} seconds."

if terminate:
logging.info(f"{msg} Shutting down the waiting process...")
server.close()
await server.wait_closed()
break

logging.info(f"Time since last keep-alive: {elapsed_seconds}s")

logging.info("Waiting process terminated.")


if __name__ == "__main__":
if not should_halt_for_connection():
logging.info("No conditions for halting the workflow" "for connection were met")
logging.info("No conditions for halting the workflow for connection were met")
exit()

wait_for_connection()
asyncio.run(wait_for_connection())
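The wait loop in wait_for_connection relies on waking up every 60 seconds unless the stop event fires first; a stripped-down sketch of that asyncio pattern, with an illustrative 2-second interval and a simulated close after 5 seconds:

import asyncio


async def demo():
    stop_event = asyncio.Event()
    # Simulate the companion script reporting a closed connection after 5 seconds.
    asyncio.get_running_loop().call_later(5, stop_event.set)

    while not stop_event.is_set():
        # Wake up every 2 seconds for a status check, or immediately
        # once the stop event is set - the same pattern as the main loop.
        await asyncio.wait(
            [asyncio.create_task(stop_event.wait())],
            timeout=2,
            return_when=asyncio.FIRST_COMPLETED,
        )
        print("status check")

    print("stop event received, exiting")


asyncio.run(demo())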