diff --git a/.github/workflows/wait-for-connection-test.yaml b/.github/workflows/wait-for-connection-test.yaml
index c99306c505c2..ddb5f45031c0 100644
--- a/.github/workflows/wait-for-connection-test.yaml
+++ b/.github/workflows/wait-for-connection-test.yaml
@@ -8,9 +8,13 @@ on:
   workflow_dispatch:
     inputs:
       halt-for-connection:
-        description: 'Should this invocation wait for a remote connection?'
-        required: false
-        default: '0'
+        description: 'Should this workflow run wait for a remote connection?'
+        type: choice
+        required: true
+        default: 'no'
+        options:
+          - 'yes'
+          - 'no'
 # Cancel any previous iterations if a new commit is pushed
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -20,7 +24,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        runner: ["linux-x86-n2-64","linux-arm64-t2a-48"]
+        runner: ["linux-x86-n2-16"]
         instances: ["1"]
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 60
@@ -30,8 +34,14 @@ jobs:
     steps:
       - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
       # Halt for connection if workflow dispatch is told to or if it is a retry with the label halt_on_retry
+      - name: Get external actions
+        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
+        with:
+          repository: google-ml-infra/actions
+          ref: ssh_graceful_timeout
+          path: cool_actions
       - name: Wait For Connection
-        uses: ./actions/ci_connection/
+        uses: ./cool_actions/ci_connection/
         with:
           halt-dispatch-input: ${{ inputs.halt-for-connection }}
       - name: Echo
diff --git a/actions/ci_connection/__init__.py b/actions/ci_connection/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/actions/ci_connection/action.yaml b/actions/ci_connection/action.yaml
index 59dc39278db2..9b2696a29b4d 100644
--- a/actions/ci_connection/action.yaml
+++ b/actions/ci_connection/action.yaml
@@ -1,44 +1,26 @@
 name: "Wait For Connection"
-description: 'Action to wait for connection from user'
+description: 'Action to wait for connection from user (conditionally)'
 inputs:
   halt-dispatch-input:
-    description: 'Should the action wait for user connection from workflow_dispatch'
+    description: 'Should the action wait for user connection from workflow_dispatch?'
     required: false
-    default: "0"
-  should-wait-retry-tag:
-    description: "Tag that will flag action to wait on reruns if present"
-    required: false
-    default: "CI Connection Halt - On Retry"
-  should-wait-always-tag:
-    description: "Tag that will flag action to wait on reruns if present"
-    required: false
-    default: "CI Connection Halt - Always"
-  repository:
-    description: 'Repository name with owner. For example, actions/checkout'
-    default: ${{ github.repository }}
+    default: "no"
 runs:
   using: "composite"
   steps:
-    - name: Print halt conditions
-      shell: bash
-      run: |
-        echo "All labels: ${{ toJSON(github.event.pull_request.labels.*.name) }}"
-        echo "Halt retry tag: ${{ inputs.should-wait-retry-tag }}"
-        echo "Halt always tag ${{ inputs.should-wait-always-tag}}"
-        echo "Should halt input: ${{ inputs.halt-dispatch-input }}"
-        echo "Reattempt count: ${{ github.run_attempt }}"
-        echo "PR number ${{ github.event.number }}"
-    - name: Halt For Connection
+    - name: Wait for connection (conditionally)
       shell: bash
-      if: |
-        (contains(github.event.pull_request.labels.*.name, inputs.should-wait-retry-tag) && github.run_attempt > 1) ||
-        contains(github.event.pull_request.labels.*.name, inputs.should-wait-always-tag) ||
-        inputs.halt-dispatch-input == '1' ||
-        inputs.halt-dispatch-input == 'yes'
       env:
-        REPOSITORY: ${{ inputs.repository }}
-        INTERACTIVE_CI: 1
+        # This variable is on by default, but can be overridden for convenience,
+        # in case the workflow calling this action should not be halted, despite
+        # other labels/inputs.
+        INTERACTIVE_CI: 1
         PYTHONUNBUFFERED: 1
+        HALT_DISPATCH_INPUT: ${{ inputs.halt-dispatch-input }}
+      # The calling workflow shouldn't fail in case this step does
+      continue-on-error: true
       run: |
-        echo "$GITHUB_ACTION_PATH"
-        python3 $GITHUB_ACTION_PATH/wait_for_connection.py
+        # Pick an existing Python alias
+        python_bin=$(which python3 2>/dev/null || which python)
+        # Wait for the connection, if a halt was requested via a label/input
+        "$python_bin" "$GITHUB_ACTION_PATH/wait_for_connection.py"
diff --git a/actions/ci_connection/get_labels.py b/actions/ci_connection/get_labels.py
new file mode 100644
index 000000000000..47bff1376535
--- /dev/null
+++ b/actions/ci_connection/get_labels.py
@@ -0,0 +1,133 @@
+# Copyright 2024 Google LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Retrieve PR labels, if any.
+
+While these labels are also available via GH context, and the event payload
+file, they may be stale:
+https://github.com/orgs/community/discussions/39062
+
+Thus, the API is used as the main source, with the event payload file
+being the fallback.
+
+The script is only geared towards use within a GH Action run.
+"""
+
+import json
+import logging
+import os
+import re
+import time
+import urllib.request
+
+
+def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
+  """Get the most up-to-date labels.
+
+  In case this is not a PR, return an empty list.
+
+  Args:
+    print_to_stdout: Whether to also print the resulting list to stdout.
+
+  Returns:
+    The names of the labels currently attached to the PR.
+
+  Raises:
+    TypeError: If the GITHUB_REF env var is unset or empty.
+  """
+  # Check if this is a PR (pull request)
+  github_ref = os.getenv('GITHUB_REF', '')
+  if not github_ref:
+    raise TypeError('GITHUB_REF is not defined. '
+                    'Is this being run outside of GitHub Actions?')
+
+  # Outside a PR context - no labels to be found
+  if not github_ref.startswith('refs/pull/'):
+    logging.debug('Not a PR workflow run, returning an empty label list')
+    if print_to_stdout:
+      print([])
+    return []
+
+  # Get the PR number
+  # Since passing the previous check confirms this is a PR, there's no need
+  # to safeguard this regex
+  gh_issue = re.search(r'refs/pull/(\d+)/merge', github_ref).group(1)
+  gh_repo = os.getenv('GITHUB_REPOSITORY')
+  labels_url = (
+    f'https://api.github.com/repos/{gh_repo}/issues/{gh_issue}/labels'
+  )
+  logging.debug(f'{gh_issue=!r}\n'
+                f'{gh_repo=!r}')
+
+  wait_time = 3
+  total_attempts = 3
+  cur_attempt = 1
+  data = None
+
+  # Try retrieving the labels' info via API
+  while cur_attempt <= total_attempts:
+    request = urllib.request.Request(
+      labels_url,
+      headers={'Accept': 'application/vnd.github+json',
+               'X-GitHub-Api-Version': '2022-11-28'}
+    )
+    logging.info(f'Retrieving PR labels via API - attempt {cur_attempt}...')
+    try:
+      # urlopen() raises an OSError subclass (HTTPError/URLError) on HTTP
+      # error statuses and network failures, rather than returning the
+      # response. Catch it, so that the retries, and the event payload
+      # fallback below, actually get a chance to run.
+      with urllib.request.urlopen(request) as response:
+        status = response.status
+        if status == 200:
+          data = response.read().decode('utf-8')
+    except OSError as e:
+      logging.error(f'Request failed: {e}')
+    else:
+      if data is not None:
+        logging.debug('API labels data: \n'
+                      f'{data}')
+        break
+      logging.error(f'Request failed with status code: {status}')
+
+    cur_attempt += 1
+    if cur_attempt <= total_attempts:
+      logging.info(f'Trying again in {wait_time} seconds')
+      time.sleep(wait_time)
+
+  # The null check is probably unnecessary, but rather be safe
+  if data and data != 'null':
+    data_json = json.loads(data)
+  else:
+    # Fall back on labels from the event's payload, if API failed (unlikely)
+    event_payload_path = os.getenv('GITHUB_EVENT_PATH')
+    with open(event_payload_path, 'r', encoding='utf-8') as event_payload:
+      data_json = json.load(event_payload).get('pull_request',
+                                               {}).get('labels', [])
+      logging.info('Using fallback labels')
+      logging.info(f'Fallback labels: \n'
+                   f'{data_json}')
+
+  labels = [label['name'] for label in data_json]
+  logging.debug(f'Final labels: \n'
+                f'{labels}')
+
+  # Output the labels to stdout for further use elsewhere
+  if print_to_stdout:
+    print(labels)
+  return labels
+
+
+if __name__ == '__main__':
+  retrieve_labels(print_to_stdout=True)
diff --git a/actions/ci_connection/wait_for_connection.py b/actions/ci_connection/wait_for_connection.py
index e752d0afb747..38feb758029b 100644
--- a/actions/ci_connection/wait_for_connection.py
+++ b/actions/ci_connection/wait_for_connection.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import os
 from multiprocessing.connection import Listener
 import time
@@ -19,99 +20,165 @@
 import threading
 import sys
 
+from get_labels import retrieve_labels
+
+# Check if debug logging should be enabled for the script:
+# WAIT_FOR_CONNECTION_DEBUG is a custom variable.
+# RUNNER_DEBUG and ACTIONS_RUNNER_DEBUG are GH env vars, which can be set
+# in various ways, one of them - enabling debug logging from the UI, when
+# triggering a run:
+# https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables
+# https://docs.github.com/en/actions/monitoring-and-troubleshooting-workflows/troubleshooting-workflows/enabling-debug-logging#enabling-runner-diagnostic-logging
+_SHOW_DEBUG = bool(
+  os.getenv("WAIT_FOR_CONNECTION_DEBUG",
+            os.getenv("RUNNER_DEBUG",
+                      os.getenv("ACTIONS_RUNNER_DEBUG")))
+)
+logging.basicConfig(level=logging.INFO if not _SHOW_DEBUG else logging.DEBUG,
+                    format="%(levelname)s: %(message)s", stream=sys.stderr)
+
+
 last_time = time.time()
 timeout = 600  # 10 minutes for initial connection
 keep_alive_timeout = (
-    900  # 30 minutes for keep-alive if no closed message (allow for reconnects)
+    900  # 15 minutes for keep-alive, if no closed message (allow for reconnects)
 )
 
 
+def _is_truthy_env_var(var_name: str) -> bool:
+  """Check if an env var is set to a value outside the known falsy choices."""
+  var_val = os.getenv(var_name, "").lower()
+  negative_choices = {"0", "false", "n", "no", "none", "null", "n/a"}
+  if var_val and var_val not in negative_choices:
+    return True
+  return False
+
+
+def should_halt_for_connection() -> bool:
+  """Check if the workflow should wait, due to inputs, vars, and labels."""
+
+  logging.info("Checking if the workflow should be halted for a connection...")
+
+  if not _is_truthy_env_var("INTERACTIVE_CI"):
+    logging.info("INTERACTIVE_CI env var is not "
+                 "set, or is set to a falsy value in the workflow")
+    return False
+
+  explicit_halt_requested = _is_truthy_env_var("HALT_DISPATCH_INPUT")
+  if explicit_halt_requested:
+    logging.info("Halt for connection requested via "
+                 "explicit `halt-dispatch-input` input")
+    return True
+
+  # Check if any of the relevant labels are present
+  labels = retrieve_labels(print_to_stdout=False)
+
+  # Note: there's always a small possibility these labels may change on the
+  # repo/org level, in which case, they'd need to be updated below as well.
+
+  # TODO(belitskiy): Add the ability to halt on CI error.
+
+  always_halt_label = "CI Connection Halt - Always"
+  if always_halt_label in labels:
+    logging.info(f"Halt for connection requested via presence "
+                 f"of the {always_halt_label!r} label")
+    return True
+
+  # Treat a missing var as the first attempt, rather than crashing on int(None)
+  attempt = int(os.getenv("GITHUB_RUN_ATTEMPT", "1"))
+  halt_on_retry_label = "CI Connection Halt - On Retry"
+  if attempt > 1 and halt_on_retry_label in labels:
+    logging.info(f"Halt for connection requested via presence "
+                 f"of the {halt_on_retry_label!r} label, "
+                 f"due to workflow run attempt being 2+ ({attempt})")
+    return True
+
+  return False
+
+
 def wait_for_notification(address):
   """Waits for connection notification from the listener."""
-  global last_time
+  # TODO(belitskiy): Get rid of globals?
+  global last_time, timeout
   while True:
     with Listener(address) as listener:
-      print("Waiting for connection")
+      logging.info("Waiting for connection...")
       with listener.accept() as conn:
         while True:
           try:
             message = conn.recv()
           except EOFError as e:
-            print("EOFError occurred:", e)
+            logging.error("EOFError occurred: %s", e)
             break
-          print("Received message")
+          logging.info("Received message")
           if message == "keep_alive":
-            print("Keep alive received")
+            logging.info("Keep-alive received")
             last_time = time.time()
             continue  # Keep-alive received, continue waiting
           elif message == "closed":
-            print("Connection closed by the other process.")
+            logging.info("Connection closed by the other process.")
             return  # Graceful exit
           elif message == "connected":
             last_time = time.time()
             timeout = keep_alive_timeout
-            print("Connected")
+            logging.info("Connected")
           else:
-            print("Unknown message received:", message)
+            logging.warning("Unknown message received: %s", message)
             continue
 
 
 def timer():
+  """Kill this process once `timeout` seconds pass without any activity."""
   while True:
-    print("Checking status")
+    logging.info("Checking status")
     time_elapsed = time.time() - last_time
     if time_elapsed < timeout:
-      print(f"Time since last keepalive {int(time_elapsed)}s")
+      logging.info(f"Time since last keep-alive: {int(time_elapsed)}s")
     else:
-      print("Timeout reached, exiting")
+      logging.info("Timeout reached, exiting")
       os.kill(os.getpid(), signal.SIGTERM)
     time.sleep(60)
 
 
-if __name__ == "__main__":
+def wait_for_connection():
+  """Log the connection details, then block until the session is closed."""
   address = ("localhost", 12455)  # Address and port to listen on
-  # Check if we should wait for the connection
-  wait_for_connection = False
-  # if os.environ.get("WAIT_ON_ERROR") == "1":
-  #   print("WAIT_ON_ERROR is set")
-  #   if os.getppid() != 1:
-  #     print("Previous command did not exit with success, waiting for connection")
-  #     wait_for_connection = True
-  #   else:
-  #     print("Previous command exited with success")
-  # else:
-  #   print("WAIT_ON_ERROR is not set")
-
-  if os.environ.get("INTERACTIVE_CI") == "1":
-    print("INTERACTIVE_CI is set, waiting for connection")
-    wait_for_connection = True
-  else:
-    print("INTERACTIVE_CI is not set")
-
-  if not wait_for_connection:
-    print("No condition was met to wait for connection. Continuing Job")
-    exit(0)
-
-  # Grab and print the data required to connect to this vm
-  host = os.environ.get("HOSTNAME")
-  repo = os.environ.get("REPOSITORY")
-  cluster = os.environ.get("CONNECTION_CLUSTER")
-  location = os.environ.get("CONNECTION_LOCATION")
-  ns = os.environ.get("CONNECTION_NS")
-  actions_path = os.environ.get("GITHUB_ACTION_PATH")
-
-  print("Googler connection only\nSee go/ for details")
-  print(
-    f"Connection string: ml-actions-connect --runner={host} --ns={ns} --loc={location} --cluster={cluster} --halt_directory={actions_path}"
+
+  # Print out the data required to connect to this VM
+  host = os.getenv("HOSTNAME")
+  cluster = os.getenv("CONNECTION_CLUSTER")
+  location = os.getenv("CONNECTION_LOCATION")
+  ns = os.getenv("CONNECTION_NS")
+  actions_path = os.getenv("GITHUB_ACTION_PATH")
+
+  logging.info("Googler connection only\n"
+               "See go/ml-github-actions:ssh for details")
+  logging.info(
+    f"Connection string: ml-actions-connect "
+    f"--runner={host} "
+    f"--ns={ns} "
+    f"--loc={location} "
+    f"--cluster={cluster} "
+    f"--halt_directory={actions_path}"
   )
 
-  # Thread is running as a daemon so it will quit when the
+  # Thread is running as a daemon, so it will quit when the
   # main thread terminates.
   timer_thread = threading.Thread(target=timer, daemon=True)
   timer_thread.start()
 
-  wait_for_notification(address)  # Wait for connection and get the connection object
+  # Block until the other side reports the connection as closed
+  wait_for_notification(address)
 
-  print("Exiting connection wait loop.")
+  logging.info("Exiting connection wait loop.")
   # Force a flush so we don't miss messages
   sys.stdout.flush()
+
+
+if __name__ == "__main__":
+  if not should_halt_for_connection():
+    logging.info("No conditions for halting the workflow "
+                 "for connection were met")
+    sys.exit()
+
+  wait_for_connection()