|
1 | | -# Cloud TPU CI |
2 | | -name: Cloud TPU Presubmit |
3 | | -# Run on pull_request that is labeled as "optional_ci_tpu" or workflow dispatch |
| 1 | +name: Build JAX Artifacts |
| 2 | + |
4 | 3 | on: |
5 | 4 | pull_request: |
6 | 5 | branches: |
7 | 6 | - main |
8 | | - types: [labeled, synchronize] |
9 | | - workflow_dispatch: |
10 | | -# Cancel any previous iterations if a new commit is pushed |
11 | | -concurrency: |
12 | | - group: ${{ github.workflow }}-${{ github.ref }} |
13 | | - cancel-in-progress: true |
| 7 | + workflow_call: |
| 8 | + |
14 | 9 | jobs: |
15 | | - cloud-tpu-test: |
16 | | - # TODO: confirm final naming for optional label |
17 | | - if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu') |
18 | | - name: "TPU v5e x 8 Presubmit" |
19 | | - strategy: |
20 | | - fail-fast: false # don't cancel all jobs on failure |
21 | | - matrix: |
22 | | - instances: ["one", "two"] |
23 | | - env: |
24 | | - ENABLE_PJRT_COMPATIBILITY: 1 |
25 | | - # TODO: Needs final runs-on value |
26 | | - runs-on: linux-x86-n2-16 |
27 | | - container: |
28 | | - # TODO: Needs newer, light weight image |
29 | | - image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11 |
30 | | - timeout-minutes: 45 |
| 10 | + build: |
| 11 | + continue-on-error: true |
31 | 12 | defaults: |
32 | 13 | run: |
33 | | - shell: bash -ex {0} |
34 | | - steps: |
35 | | - - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 |
36 | | - - name: Install JAX test requirements |
37 | | - run: | |
38 | | - pip install -U -r build/test-requirements.txt |
39 | | - - name: DEBUG HALT |
40 | | - run: | |
41 | | - echo "Halting" |
42 | | - sleep 30m |
43 | | -
|
44 | | -
|
45 | | -
|
46 | | -
|
47 | | -
|
48 | | -
|
49 | | -
|
| 14 | + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. |
| 15 | + shell: bash |
| 16 | + strategy: |
| 17 | + matrix: |
| 18 | + runner: ["windows-x86-n2-64-dev"] |
| 19 | + artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] |
| 20 | + python: ["3.10", "3.11", "3.12"] |
| 21 | + # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each |
| 22 | + # Python version. |
| 23 | + exclude: |
| 24 | + # Pure Python packages do not need to be built for each Python version. |
| 25 | + - artifact: "jax-cuda-pjrt" |
| 26 | + python: "3.10" |
| 27 | + - artifact: "jax-cuda-pjrt" |
| 28 | + python: "3.11" |
| 29 | + - artifact: "jax" |
| 30 | + python: "3.10" |
| 31 | + - artifact: "jax" |
| 32 | + python: "3.11" |
| 33 | + # jax is a pure Python package so it does not need to be built on multiple platforms. |
| 34 | + - artifact: "jax" |
| 35 | + runner: "windows-x86-n2-64" |
| 36 | + - artifact: "jax" |
| 37 | + runner: "linux-arm64-t2a-48" |
| 38 | + # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. |
| 39 | + - artifact: "jax-cuda-plugin" |
| 40 | + runner: "windows-x86-n2-64" |
| 41 | + - artifact: "jax-cuda-pjrt" |
| 42 | + runner: "windows-x86-n2-64" |
50 | 43 |
|
| 44 | + runs-on: ${{ matrix.runner }} |
51 | 45 |
|
| 46 | + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
| 47 | + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || |
| 48 | + (contains(matrix.runner, 'windows-x86') && null) }} |
52 | 49 |
|
| 50 | + env: |
| 51 | + # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. |
| 52 | + JAXCI_RUN_DOCKER_CONTAINER: 0 |
| 53 | + # Use RBE to build the artifacts where possibl (Linux x86 and Windows). |
| 54 | + JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 |
53 | 55 |
|
| 56 | + steps: |
| 57 | + - uses: actions/checkout@v3 |
| 58 | + # Halt for testing |
| 59 | + - name: Wait For Connection |
| 60 | + uses: ./actions/ci_connection/ |
| 61 | + - name: Build ${{ matrix.artifact }} |
| 62 | + env: |
| 63 | + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" |
| 64 | + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/${{ matrix.artifact }}" |
0 commit comments