# Test CI scripts and workflows #110
# Workflow file for this run

# Installs JAX for every (jaxlib-version, TPU type) combination in the matrix
# below and then runs ci/run_pytest_tpu.sh on the matching TPU runner.
name: Run Pytest TPU tests

on:
  pull_request:
    branches:
      - main
  # Manual runs can optionally pause at the "Wait For Connection" step.
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'

jobs:
  run_tpu_tests:
    strategy:
      # Let every matrix leg run to completion even if another leg fails.
      fail-fast: false
      matrix:
        jaxlib-version:
          - "pypi_latest"
          - "nightly"
          - "nightly+oldest_supported_libtpu"
        # Each entry names the TPU type, its core count, and the runner label
        # the job is scheduled on (see `runs-on` below).
        tpu:
          # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
          - {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}
          - {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
        # Quoted so the version is not parsed as the float 3.1.
        python: ["3.10"]
    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
    env:
      # Date suffix of the libtpu-nightly build installed by the
      # "nightly+oldest_supported_libtpu" leg. Quoted so it stays a string.
      LIBTPU_OLDEST_VERSION_DATE: "20240722"
      # Evaluates to true only for the "nightly+oldest_supported_libtpu" leg.
      ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
      # Interpreter name used by the install steps below, e.g. "python3.10".
      JAXCI_PYTHON: python${{ matrix.python }}
    steps:
      - uses: actions/checkout@v4
      - name: Install JAX test requirements
        run: |
          "$JAXCI_PYTHON" -m pip install -U -r build/test-requirements.txt
          "$JAXCI_PYTHON" -m pip install -U -r build/collect-profile-requirements.txt
      - name: Install JAX
        env:
          # Passed through the environment instead of being spliced into the
          # script text, so the value is never interpreted as shell syntax.
          JAXLIB_VERSION: ${{ matrix.jaxlib-version }}
        run: |
          # POSIX `case` rather than `[ a == b ]`, so the script also works if
          # the runner falls back from bash to sh.
          case "$JAXLIB_VERSION" in
            pypi_latest)
              "$JAXCI_PYTHON" -m pip install ".[tpu]" \
                -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              ;;
            nightly)
              "$JAXCI_PYTHON" -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
              "$JAXCI_PYTHON" -m pip install --pre libtpu \
                -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              "$JAXCI_PYTHON" -m pip install requests
              ;;
            nightly+oldest_supported_libtpu)
              "$JAXCI_PYTHON" -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
              # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
              "$JAXCI_PYTHON" -m pip install --pre "libtpu-nightly==0.1.dev${LIBTPU_OLDEST_VERSION_DATE}" \
                -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
              "$JAXCI_PYTHON" -m pip install requests
              ;;
            *)
              echo "Unknown jaxlib-version: $JAXLIB_VERSION"
              exit 1
              ;;
          esac
      # Halt for testing
      # NOTE(review): this action is referenced by the mutable `main` branch;
      # consider pinning it to a commit SHA.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Pytest TPU tests
        env:
          # Core count of the selected TPU, exported to the test script.
          JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
        run: ./ci/run_pytest_tpu.sh