Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions .github/workflows/wait-for-connection-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ on:
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this invocation wait for a remote connection?'
required: false
default: '0'
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'
# Cancel any previous iterations if a new commit is pushed
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
Expand All @@ -20,7 +24,7 @@ jobs:
strategy:
fail-fast: false
matrix:
runner: ["linux-x86-n2-64","linux-arm64-t2a-48"]
runner: ["linux-x86-n2-16"]
instances: ["1"]
runs-on: ${{ matrix.runner }}
timeout-minutes: 60
Expand All @@ -30,8 +34,14 @@ jobs:
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
# Halt for connection if workflow dispatch is told to or if it is a retry with the label halt_on_retry
- name: Get external actions
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
with:
repository: google-ml-infra/actions
ref: ssh_graceful_timeout
path: cool_actions
- name: Wait For Connection
uses: ./actions/ci_connection/
uses: ./cool_actions/ci_connection/
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Echo
Expand Down
Empty file.
48 changes: 15 additions & 33 deletions actions/ci_connection/action.yaml
Original file line number Diff line number Diff line change
@@ -1,44 +1,26 @@
name: "Wait For Connection"
description: 'Action to wait for connection from user'
description: 'Action to wait for connection from user (conditionally)'
inputs:
halt-dispatch-input:
description: 'Should the action wait for user connection from workflow_dispatch'
description: 'Should the action wait for user connection from workflow_dispatch?'
required: false
default: "0"
should-wait-retry-tag:
description: "Tag that will flag action to wait on reruns if present"
required: false
default: "CI Connection Halt - On Retry"
should-wait-always-tag:
description: "Tag that will flag action to wait on reruns if present"
required: false
default: "CI Connection Halt - Always"
repository:
description: 'Repository name with owner. For example, actions/checkout'
default: ${{ github.repository }}
default: "no"
runs:
using: "composite"
steps:
- name: Print halt conditions
shell: bash
run: |
echo "All labels: ${{ toJSON(github.event.pull_request.labels.*.name) }}"
echo "Halt retry tag: ${{ inputs.should-wait-retry-tag }}"
echo "Halt always tag ${{ inputs.should-wait-always-tag}}"
echo "Should halt input: ${{ inputs.halt-dispatch-input }}"
echo "Reattempt count: ${{ github.run_attempt }}"
echo "PR number ${{ github.event.number }}"
- name: Halt For Connection
- name: Wait for connection (conditionally)
shell: bash
if: |
(contains(github.event.pull_request.labels.*.name, inputs.should-wait-retry-tag) && github.run_attempt > 1) ||
contains(github.event.pull_request.labels.*.name, inputs.should-wait-always-tag) ||
inputs.halt-dispatch-input == '1' ||
inputs.halt-dispatch-input == 'yes'
env:
REPOSITORY: ${{ inputs.repository }}
INTERACTIVE_CI: 1
# This variable is on by default, but can be overridden for convenience,
# in case the workflow calling this action should not be halted, despite
# other labels/inputs.
INTERACTIVE_CI: 1
PYTHONUNBUFFERED: 1
HALT_DISPATCH_INPUT: ${{ inputs.halt-dispatch-input }}
# The calling workflow shouldn't fail in case this step does
continue-on-error: true
run: |
echo "$GITHUB_ACTION_PATH"
python3 $GITHUB_ACTION_PATH/wait_for_connection.py
# Pick an existing Python alias
python_bin=$(which python3 2>/dev/null || which python)
# Wait for the connection, if a halt was requested via a label/input
"$python_bin" "$GITHUB_ACTION_PATH/wait_for_connection.py"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want a script failure to fail the overall job. I think there are a few ways we can go about that but the easiest might just be an || true here to keep a script failure from getting back to the job.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added continue-on-error to the step - I think it's a little clearer

115 changes: 115 additions & 0 deletions actions/ci_connection/get_labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Retrieve PR labels, if any.

While these labels are also available via GH context, and the event payload
file, they may be stale:
https://github.com/orgs/community/discussions/39062

Thus, the API is used as the main source, with the event payload file
being the fallback.

The script is only geared towards use within a GH Action run.
"""

import json
import logging
import os
import re
import time
import urllib.request


def _fetch_labels_via_api(labels_url: str,
                          total_attempts: int = 3,
                          wait_time: int = 3,
                          timeout: int = 10) -> 'str | None':
    """Download the raw JSON labels payload, retrying on failure.

    Args:
        labels_url: GitHub REST API URL of the issue/PR labels endpoint.
        total_attempts: How many times to try before giving up.
        wait_time: Seconds to sleep between two attempts.
        timeout: Seconds to wait on the connection before an attempt fails.

    Returns:
        The decoded response body, or None if every attempt failed.
    """
    # Function-scope import: the module itself only imports `urllib.request`
    import urllib.error

    for cur_attempt in range(1, total_attempts + 1):
        request = urllib.request.Request(
            labels_url,
            headers={'Accept': 'application/vnd.github+json',
                     'X-GitHub-Api-Version': '2022-11-28'}
        )
        logging.info(f'Retrieving PR labels via API - attempt {cur_attempt}...')

        try:
            # `with` guarantees the connection is closed after reading
            with urllib.request.urlopen(request, timeout=timeout) as response:
                if response.status == 200:
                    data = response.read().decode('utf-8')
                    logging.debug('API labels data: \n'
                                  f'{data}')
                    return data
                logging.error(
                    f'Request failed with status code: {response.status}')
        except urllib.error.HTTPError as err:
            # urlopen raises on 4xx/5xx instead of returning the response,
            # so the status code has to be read off the exception
            logging.error(f'Request failed with status code: {err.code}')
        except OSError as err:
            # URLError, timeouts, and other socket-level failures
            logging.error(f'Request failed: {err}')

        if cur_attempt < total_attempts:
            logging.info(f'Trying again in {wait_time} seconds')
            time.sleep(wait_time)

    return None


def _load_fallback_labels() -> list:
    """Read the PR label objects from the (possibly stale) event payload."""
    event_payload_path = os.getenv('GITHUB_EVENT_PATH')
    with open(event_payload_path, 'r', encoding='utf-8') as event_payload:
        data_json = json.load(event_payload).get('pull_request',
                                                 {}).get('labels', [])
    logging.info('Using fallback labels')
    logging.info(f'Fallback labels: \n'
                 f'{data_json}')
    return data_json


def retrieve_labels(print_to_stdout: bool = True) -> list[str]:
    """Get the most up-to-date labels of the PR that triggered this run.

    The GitHub API is queried first; if it cannot be reached (or the PR
    number cannot be determined), the labels recorded in the event payload
    file are used instead.

    Args:
        print_to_stdout: Whether to also print the resulting list to stdout.

    Returns:
        The label names. In case this is not a PR, an empty list.

    Raises:
        TypeError: If GITHUB_REF is not defined in the environment.
    """
    # Check if this is a PR (pull request)
    github_ref = os.getenv('GITHUB_REF', '')
    if not github_ref:
        raise TypeError('GITHUB_REF is not defined. '
                        'Is this being run outside of GitHub Actions?')

    # Outside a PR context - no labels to be found
    if not github_ref.startswith('refs/pull/'):
        logging.debug('Not a PR workflow run, returning an empty label list')
        if print_to_stdout:
            print([])
        return []

    # Get the PR number. The prefix check above does not guarantee the
    # `/merge` suffix (e.g. `refs/pull/<n>/head`), so the match may be None.
    data = None
    pr_match = re.search(r'refs/pull/(\d+)/merge', github_ref)
    if pr_match is None:
        logging.warning(f'Could not parse a PR number from {github_ref!r}, '
                        'skipping the API request')
    else:
        gh_issue = pr_match.group(1)
        gh_repo = os.getenv('GITHUB_REPOSITORY')
        labels_url = (
            f'https://api.github.com/repos/{gh_repo}/issues/{gh_issue}/labels'
        )
        logging.debug(f'{gh_issue=!r}\n'
                      f'{gh_repo=!r}')
        # Try retrieving the labels' info via API
        data = _fetch_labels_via_api(labels_url)

    # The null check is probably unnecessary, but rather be safe
    if data and data != 'null':
        data_json = json.loads(data)
    else:
        # Fall back on labels from the event's payload, if API failed
        data_json = _load_fallback_labels()

    labels = [label['name'] for label in data_json]
    logging.debug(f'Final labels: \n'
                  f'{labels}')

    # Output the labels to stdout for further use elsewhere
    if print_to_stdout:
        print(labels)
    return labels


if __name__ == '__main__':
    # Script entry point: emit the label list on stdout so that the
    # invoking process can capture it.
    retrieve_labels(print_to_stdout=True)
Loading