Skip to content

Commit b85462d

Browse files
authored
Update cloud-tpu-presubmit.yml
1 parent e6db8e2 commit b85462d

File tree

1 file changed

+53
-42
lines changed

1 file changed

+53
-42
lines changed
Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,64 @@
1-
# Cloud TPU CI
2-
name: Cloud TPU Presubmit
3-
# Run on pull_request that is labeled as "optional_ci_tpu" or workflow dispatch
1+
name: Build JAX Artifacts
2+
43
on:
54
pull_request:
65
branches:
76
- main
8-
types: [labeled, synchronize]
9-
workflow_dispatch:
10-
# Cancel any previous iterations if a new commit is pushed
11-
concurrency:
12-
group: ${{ github.workflow }}-${{ github.ref }}
13-
cancel-in-progress: true
7+
workflow_call:
8+
149
jobs:
15-
cloud-tpu-test:
16-
# TODO: confirm final naming for optional label
17-
if: contains(github.event.pull_request.labels.*.name, 'optional_ci_tpu')
18-
name: "TPU v5e x 8 Presubmit"
19-
strategy:
20-
fail-fast: false # don't cancel all jobs on failure
21-
matrix:
22-
instances: ["one", "two"]
23-
env:
24-
ENABLE_PJRT_COMPATIBILITY: 1
25-
# TODO: Needs final runs-on value
26-
runs-on: linux-x86-n2-16
27-
container:
28-
# TODO: Needs newer, light weight image
29-
image: index.docker.io/tensorflow/build@sha256:7fb38f0319bda36393cad7f40670aa22352b44421bb906f5cf34d543acd8e1d2 # ratchet:tensorflow/build:latest-python3.11
30-
timeout-minutes: 45
10+
build:
11+
continue-on-error: true
3112
defaults:
3213
run:
33-
shell: bash -ex {0}
34-
steps:
35-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
36-
- name: Install JAX test requirements
37-
run: |
38-
pip install -U -r build/test-requirements.txt
39-
- name: DEBUG HALT
40-
run: |
41-
echo "Halting"
42-
sleep 30m
43-
44-
45-
46-
47-
48-
49-
14+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
15+
shell: bash
16+
strategy:
17+
matrix:
18+
runner: ["windows-x86-n2-64-dev"]
19+
artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
20+
python: ["3.10", "3.11", "3.12"]
21+
# jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
22+
# Python version.
23+
exclude:
24+
# Pure Python packages do not need to be built for each Python version.
25+
- artifact: "jax-cuda-pjrt"
26+
python: "3.10"
27+
- artifact: "jax-cuda-pjrt"
28+
python: "3.11"
29+
- artifact: "jax"
30+
python: "3.10"
31+
- artifact: "jax"
32+
python: "3.11"
33+
# jax is a pure Python package so it does not need to be built on multiple platforms.
34+
- artifact: "jax"
35+
runner: "windows-x86-n2-64"
36+
- artifact: "jax"
37+
runner: "linux-arm64-t2a-48"
38+
# jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows.
39+
- artifact: "jax-cuda-plugin"
40+
runner: "windows-x86-n2-64"
41+
- artifact: "jax-cuda-pjrt"
42+
runner: "windows-x86-n2-64"
5043

44+
runs-on: ${{ matrix.runner }}
5145

46+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
47+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
48+
(contains(matrix.runner, 'windows-x86') && null) }}
5249

50+
env:
51+
# Do not run Docker container for Linux runners. Linux runners already run in a Docker container.
52+
JAXCI_RUN_DOCKER_CONTAINER: 0
53+
# Use RBE to build the artifacts where possibl (Linux x86 and Windows).
54+
JAXCI_BUILD_ARTIFACT_WITH_RBE: 1
5355

56+
steps:
57+
- uses: actions/checkout@v3
58+
# Halt for testing
59+
- name: Wait For Connection
60+
uses: ./actions/ci_connection/
61+
- name: Build ${{ matrix.artifact }}
62+
env:
63+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
64+
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/${{ matrix.artifact }}"

0 commit comments

Comments
 (0)