Update jax-array-api.yml #876
Workflow file for this run
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Cloud TPU CI | |
| # | |
| # This job currently runs once per day. We use self-hosted TPU runners, so we'd | |
| # have to add more runners to run on every commit. | |
| # | |
| # This job's build matrix runs over several TPU architectures using both the | |
| # latest released jaxlib on PyPI ("pypi_latest") and the latest nightly | |
| # jaxlib ("nightly"). It also installs a matching libtpu, either the one pinned | |
| # to the release for "pypi_latest", or the latest nightly for "nightly". It | |
| # always locally installs jax from github head (already checked out by the | |
| # Github Actions environment). | |
| # Display name of this workflow. | |
| name: CI - Cloud TPU (nightly) | |
| # The triggers below are commented out to disable the schedule. This workflow is | |
| # slated for removal; the new test workflow is in | |
| # "wheel_tests_nightly_release.yml". | |
| # on: | |
| # schedule: | |
| # - cron: "0 2,14 * * *" # 02:00 and 14:00 UTC, i.e. 7pm and 7am PDT | |
| # workflow_dispatch: # allows triggering the workflow run manually | |
| # Limit the workflow token to read-only access to repository contents. This | |
| # should also be set to read-only in the project settings, but it's nice to | |
| # document and enforce the permissions here. | |
| permissions: | |
| contents: read | |
| jobs: | |
| # Job that installs JAX in one of several jaxlib/libtpu configurations and runs | |
| # the test suite on a TPU runner. One job instance is created per combination of | |
| # the matrix values below, minus the combinations listed under "exclude". | |
| cloud-tpu-test: | |
| strategy: | |
| fail-fast: false # don't cancel all jobs on failure | |
| matrix: | |
| # Selects which branch of the "Install JAX" step is taken. | |
| jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"] | |
| # Per-TPU settings: "type" appears in the job name and failure message, | |
| # "cores" is passed to pytest as the -n worker count, and "runner" is the | |
| # runs-on label for the job. | |
| tpu: [ | |
| {type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"}, | |
| {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}, | |
| {type: "v6e-8", cores: "8", runner: "linux-x86-ct6e-180-8tpu"} | |
| ] | |
| # Quoted so the version stays the string "3.10" rather than the number 3.1. | |
| python-version: ["3.10"] | |
| # Exclude v6e-8 tests for nightly+oldest_supported_libtpu and pypi_latest for resource constraints. | |
| exclude: | |
| - tpu: | |
| type: "v6e-8" | |
| jaxlib-version: "nightly+oldest_supported_libtpu" | |
| - tpu: | |
| type: "v6e-8" | |
| jaxlib-version: "pypi_latest" | |
| name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})" | |
| env: | |
| # Date suffix interpolated into the libtpu-nightly version pin used by the | |
| # "nightly+oldest_supported_libtpu" configuration. | |
| LIBTPU_OLDEST_VERSION_DATE: 20241205 | |
| # Interpreter command used by every step, e.g. "python3.10". | |
| PYTHON: python${{ matrix.python-version }} | |
| runs-on: ${{ matrix.tpu.runner }} | |
| container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" | |
| timeout-minutes: 180 | |
| defaults: | |
| run: | |
| # -e aborts a step at the first failing command; -x echoes each command. | |
| shell: bash -ex {0} | |
| steps: | |
| # https://opensource.google/documentation/reference/github/services#actions | |
| # mandates using a specific commit for non-Google actions. We use | |
| # https://github.com/sethvargo/ratchet to pin specific versions. | |
| # Check out this repository (the default for actions/checkout). | |
| - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| # Checkout XLA at head, if we're building jaxlib at head. The checkout lands in | |
| # the "xla" subdirectory, which the "Install JAX" step passes to the build as | |
| # --local_xla_path. | |
| - name: Checkout XLA at head | |
| uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | |
| if: ${{ matrix.jaxlib-version == 'head' }} | |
| with: | |
| repository: openxla/xla | |
| path: xla | |
| # We need to mark the GitHub workspace as safe as otherwise git commands will fail. | |
| - name: Mark GitHub workspace as safe | |
| run: | | |
| git config --global --add safe.directory "$GITHUB_WORKSPACE" | |
| # Install (or upgrade, via -U) the packages listed in the two requirements files. | |
| - name: Install JAX test requirements | |
| run: | | |
| $PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt | |
| # Install jax, jaxlib and libtpu according to matrix.jaxlib-version, then print | |
| # the resulting versions to the log. | |
| - name: Install JAX | |
| run: | | |
| # Start from a clean slate: remove any jax, jaxlib or libtpu already installed. | |
| $PYTHON -m uv pip uninstall jax jaxlib libtpu | |
| if [ "${{ matrix.jaxlib-version }}" == "head" ]; then | |
| # Build and install jaxlib at head | |
| $PYTHON build/build.py build --wheels=jaxlib \ | |
| --bazel_options=--config=rbe_linux_x86_64 \ | |
| --local_xla_path="$(pwd)/xla" \ | |
| --verbose | |
| # Install jaxlib, "jax" at head, and libtpu | |
| $PYTHON -m uv pip install dist/*.whl \ | |
| -U -e . \ | |
| --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then | |
| # Install the local jax checkout with its "tpu" extra. | |
| $PYTHON -m uv pip install .[tpu] \ | |
| -f https://storage.googleapis.com/jax-releases/libtpu_releases.html | |
| elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then | |
| # Install the local jax checkout, allowing pre-releases from the given index, | |
| # plus libtpu and requests. | |
| $PYTHON -m uv pip install \ | |
| --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ | |
| libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ | |
| requests | |
| elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then | |
| # Same as "nightly", but with libtpu-nightly pinned to the build dated | |
| # LIBTPU_OLDEST_VERSION_DATE instead of the latest libtpu. | |
| # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. | |
| $PYTHON -m uv pip install \ | |
| --pre . -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ | |
| libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ | |
| -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ | |
| requests | |
| else | |
| # Fail fast on a matrix value that has no install recipe above. | |
| echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" | |
| exit 1 | |
| fi | |
| # Log the versions that actually got installed. | |
| $PYTHON -c 'import sys; print("python version:", sys.version)' | |
| $PYTHON -c 'import jax; print("jax version:", jax.__version__)' | |
| $PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)' | |
| # Expects libtpu.so under the dist-packages path below; because the shell runs | |
| # with -e, the step fails if grep finds no 'Built on' line. | |
| strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on' | |
| $PYTHON -c 'import jax; print("libtpu version:", | |
| jax.lib.xla_bridge.get_backend().platform_version)' | |
| # Run the test suite in three pytest invocations: single-accelerator tests in | |
| # parallel, the Pallas printing tests, then the multi-accelerator tests. The | |
| # shell runs with -e, so a failing invocation skips the ones after it. | |
| - name: Run tests | |
| env: | |
| JAX_PLATFORMS: tpu,cpu | |
| PY_COLORS: 1 | |
| run: | | |
| # We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not | |
| # guarantee anything about forward compatibility (unless jax.export is used) and the 12 | |
| # week compatibility window accumulates way too many failures. | |
| # NOTE(review): IGNORE_FLAGS is only passed to the first pytest invocation; the | |
| # Pallas printing tests and the multi-accelerator run below do not receive it, so | |
| # tests under tests/pallas can still run in the oldest libtpu build - confirm | |
| # this is intended. | |
| IGNORE_FLAGS= | |
| if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then | |
| IGNORE_FLAGS="--ignore=tests/pallas" | |
| fi | |
| # Run single-accelerator tests in parallel | |
| # $IGNORE_FLAGS is left unquoted so that an empty value adds no argument. | |
| # PallasCallPrintTest is deselected here because it is run separately below. | |
| JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \ | |
| --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \ | |
| --maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples | |
| # Run Pallas printing tests, which need to run with I/O capturing disabled. | |
| TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \ | |
| tests/pallas/tpu_pallas_test.py::PallasCallPrintTest | |
| # Run multi-accelerator across all chips | |
| $PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests | |
| # Post a message to the chat webhook stored in the BUILD_CHAT_WEBHOOK secret when | |
| # the job failed or was cancelled. The message names the workflow, the | |
| # jaxlib/libtpu configuration, the TPU type, and links to the run. | |
| - name: Send chat on failure | |
| # Don't notify when testing the workflow from a branch. | |
| # Notifications are also skipped for the "nightly+oldest_supported_libtpu" | |
| # configuration. | |
| if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }} | |
| run: | | |
| curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \ | |
| --header 'Content-Type: application/json' \ | |
| --data-raw "{ | |
| 'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID' | |
| }" | |