Skip to content

Test CI scripts and workflows #290

Test CI scripts and workflows

Test CI scripts and workflows #290

Workflow file for this run

name: CI - Pytest CPU
on:
pull_request:
branches:
- main
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
required: true
default: "linux-x86-n2-16"
python:
description: "Which python version should the artifact be built for?"
type: string
required: true
default: "3.12"
enable-x64:
description: "Should x64 mode be enabled?"
type: string
required: true
default: "0"
gcs_download_uri:
description: "GCS location URI from where the artifacts should be downloaded"
required: false
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
type: string
jobs:
run-tests:
defaults:
run:
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
shell: bash
runs-on: ${{ inputs.runner }}
container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
(contains(inputs.runner, 'windows-x86') && null) }}
name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
env:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
steps:
- uses: actions/checkout@v4
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Set env vars for use in artifact download URL
run: |
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)
# Adjust os and arch for Windows x86
if [[ $os =~ "msys_nt" ]] && [[ $arch =~ "x86_64" ]]; then
os="win"
arch="amd64"
fi
# Get the major and minor version of Python.
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
echo "OS=${os}" >> $GITHUB_ENV
echo "ARCH=${arch}" >> $GITHUB_ENV
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
- name: Download the jaxlib wheel from GCS (non-Windows runs)
if: ${{ !contains(inputs.runner, 'windows-x86') }}
run: >-
mkdir -p $(pwd)/dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/
- name: Download the jaxlib wheel from GCS (Windows runs)
if: ${{ contains(inputs.runner, 'windows-x86') }}
shell: cmd
run: >-
mkdir dist &&
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/
- name: Install dependencies
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
- name: Run Pytest CPU tests
run: ./ci/run_pytest_cpu.sh