forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
149 lines (144 loc) · 6.26 KB
/
bazel_test_tpu.yml
File metadata and controls
149 lines (144 loc) · 6.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# CI - Bazel test TPU
#
# This workflow runs the TPU tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Bazel TPU tests.
#
# It consists of the following job:
# run-bazel-tests:
#    - Downloads the jaxlib wheel from a GCS bucket if build_jaxlib is false.
#    - Downloads the libtpu wheels.
#    - Executes the `run_bazel_test_tpu.sh` script, which performs the following actions:
#      - Runs the TPU tests with Bazel.
name: CI - Bazel test TPU
# Reusable workflow: the only trigger is `workflow_call`, so it runs solely when another
# workflow invokes it. All inputs are strings; flag-like inputs use string values such as
# "0", 'no' and 'false' rather than YAML booleans.
on:
  workflow_call:
    inputs:
      # Note that the values for runners, cores, and tpu-type are linked to each other.
      # For example, the v5e-8 TPU type requires 8 cores. For ease of reference, we use the
      # following mapping:
      # {tpu-type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
      # {tpu-type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        default: "linux-x86-ct5lp-224-8tpu"
      # Exported to the job as JAXCI_TPU_CORES.
      cores:
        description: "How many TPU cores should the test use?"
        type: string
        default: "8"
      # Only used in the job's display name within this file.
      tpu-type:
        description: "Which TPU type is used for testing?"
        type: string
        default: "v5e-8"
      # Quoted so that versions such as "3.10" are not read as the float 3.1.
      python:
        description: "Which Python version should be used for testing?"
        type: string
        default: "3.12"
      # Exported to the job as JAXCI_RUN_FULL_TPU_TEST_SUITE.
      # NOTE(review): presumably "1" enables the full suite - confirm in ci/run_bazel_test_tpu.sh.
      run-full-tpu-test-suite:
        description: "Should the full TPU test suite be run?"
        type: string
        default: "0"
      libtpu-version-type:
        description: "Which libtpu version should be used for testing?"
        type: string
        # Choices are:
        # - "nightly": Use the nightly libtpu wheel.
        # - "pypi_latest": Use the latest libtpu wheel from PyPI.
        # - "oldest_supported_libtpu": Use the oldest supported libtpu wheel.
        # Any other value makes the "Download libtpu wheels" step exit with an error.
        default: "nightly"
      # Forwarded to the download-jax-cpu-wheels action.
      jaxlib-version:
        description: "Which jaxlib version to test? (head/pypi_latest)"
        type: string
        required: false
        default: "head"
      # Forwarded to the download-jax-cpu-wheels action.
      skip-download-jaxlib-from-gcs:
        description: "Whether to skip downloading the jaxlib artifact from GCS (e.g for testing a jax only release)"
        type: string
        default: '0'
      # Forwarded to the download-jax-cpu-wheels action.
      gcs_download_uri:
        description: "GCS location prefix from where the artifacts should be downloaded"
        type: string
        required: false
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      # The only required input (no default); exported to the job as JAXCI_CLONE_MAIN_XLA.
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: string
        required: true
      # Passed to the ci_connection action as `halt-dispatch-input`.
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: string
        default: 'no'
      # The "Download JAX wheels" step runs only when this is exactly the string 'false'.
      build_jaxlib:
        description: 'Should jaxlib be built from source?'
        type: string
        required: false
        default: 'false'
      # Exported to the job as JAXCI_BUILD_JAX.
      build_jax:
        description: 'Should jax be built from source?'
        type: string
        required: false
        default: 'false'
# An empty mapping grants the workflow's GITHUB_TOKEN no permissions.
permissions: {}
# Workflow-level environment: pip resolves packages through this index URL in every step.
env:
  PIP_INDEX_URL: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"
jobs:
  # Single job: check out the repo, fetch the JAX and libtpu wheels, optionally wait for a
  # remote connection, then run ci/run_bazel_test_tpu.sh.
  run-bazel-tests:
    defaults:
      run:
        shell: bash
    runs-on: ${{ inputs.runner }}
    container: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest"
    # Begin Presubmit Naming Check - name modification requires internal check to be updated
    name: "${{ inputs.tpu-type }}, py ${{ inputs.python }}, libtpu=${{ inputs.libtpu-version-type }}"
    # End Presubmit Naming Check github-tpu-presubmits
    env:
      # Date suffix of the oldest supported libtpu nightly; quoted so it stays a string
      # instead of being typed as an integer.
      LIBTPU_OLDEST_VERSION_DATE: "20250228"
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      # Interpreter name used by the libtpu download step, e.g. "python3.12".
      JAXCI_PYTHON: "python${{ inputs.python }}"
      JAXCI_RUN_FULL_TPU_TEST_SUITE: "${{ inputs.run-full-tpu-test-suite }}"
      JAXCI_TPU_CORES: "${{ inputs.cores }}"
      JAXCI_BUILD_JAXLIB: "${{ inputs.build_jaxlib }}"
      JAXCI_BUILD_JAX: "${{ inputs.build_jax }}"
      JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
    steps:
      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
        with:
          persist-credentials: false
      # Skipped when jaxlib is built from source (build_jaxlib is anything other than 'false').
      - name: Download JAX wheels
        if: inputs.build_jaxlib == 'false'
        uses: ./.github/actions/download-jax-cpu-wheels
        with:
          runner: ${{ inputs.runner }}
          python: ${{ inputs.python }}
          skip-download-jaxlib-from-gcs: ${{ inputs.skip-download-jaxlib-from-gcs }}
          jaxlib-version: ${{ inputs.jaxlib-version }}
          gcs_download_uri: ${{ inputs.gcs_download_uri }}
      # Downloads the libtpu wheel selected by libtpu-version-type into ./dist.
      - name: Download libtpu wheels
        # The input reaches the script through an environment variable rather than being
        # expanded into the script text.
        env:
          INPUTS_LIBTPU_VERSION_TYPE: ${{ inputs.libtpu-version-type }}
        run: |
          mkdir -p "$(pwd)/dist"
          $JAXCI_PYTHON -m pip install --upgrade pip
          echo "Download the wheel into a local directory"
          if [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "nightly" ]]; then
            $JAXCI_PYTHON -m pip download -d "$(pwd)/dist" --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          elif [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "pypi_latest" ]]; then
            echo "Using latest libtpu from PyPI"
            $JAXCI_PYTHON -m pip download -d "$(pwd)/dist" libtpu
          elif [[ "${INPUTS_LIBTPU_VERSION_TYPE}" == "oldest_supported_libtpu" ]]; then
            echo "Using oldest supported libtpu"
            # LIBTPU_OLDEST_VERSION_DATE comes from the job-level env block.
            $JAXCI_PYTHON -m pip download -d "$(pwd)/dist" --pre "libtpu-nightly==0.1.dev${LIBTPU_OLDEST_VERSION_DATE}" \
              -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
            # Writing to GITHUB_ENV exposes libtpu_version_type to the steps that follow.
            echo "libtpu_version_type=oldest_supported_libtpu" >> "$GITHUB_ENV"
          else
            echo "Unknown libtpu version type: ${INPUTS_LIBTPU_VERSION_TYPE}"
            exit 1
          fi
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel TPU tests with build_jaxlib=${{ format('{0}', inputs.build_jaxlib) }}
        # 20 minutes when the triggering event is a pull request, 210 minutes otherwise.
        timeout-minutes: ${{ github.event_name == 'pull_request' && 20 || 210 }}
        run: ./ci/run_bazel_test_tpu.sh