Skip to content

Commit 92a7cae

Browse files
Add CPU wheel build and tests in JAX presubmit.
The jobs are needed to replicate postsubmit Github Workflow actions and catch the majority of the issues in the presubmit. The tests do not block PRs yet. PiperOrigin-RevId: 795222280
1 parent add3ce8 commit 92a7cae

File tree

9 files changed

+399
-54
lines changed

9 files changed

+399
-54
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# CI - Bazel build wheels (RBE)
2+
#
3+
# This workflow builds all jax wheels using Bazel.
4+
#
5+
# It consists of the following jobs:
6+
# build-jax-artifact:
7+
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
8+
# build-jaxlib-artifact:
9+
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
10+
# build-cuda-artifacts:
11+
# - Builds the cuda artifacts from source via build.py invocation. Uses build_artifacts.yml.
12+
13+
name: CI - Bazel build wheels (RBE)
14+
15+
on:
16+
workflow_dispatch: # allows triggering the workflow run manually
17+
inputs:
18+
runner:
19+
description: "Which runner should the workflow run on?"
20+
type: choice
21+
default: "linux-x86-n2-16"
22+
options:
23+
- "linux-x86-n2-16"
24+
- "linux-arm64-t2a-48"
25+
- "linux-arm64-c4a-16"
26+
- "windows-x86-n2-16"
27+
artifact:
28+
description: "Which JAX artifact to build?"
29+
type: choice
30+
default: "jaxlib"
31+
options:
32+
- "jax"
33+
- "jaxlib"
34+
- "jax-cuda-plugin"
35+
- "jax-cuda-pjrt"
36+
python:
37+
description: "Which python version should the artifact be built for?"
38+
type: choice
39+
default: "3.12"
40+
options:
41+
- "3.11"
42+
- "3.12"
43+
- "3.13"
44+
- "3.14"
45+
- "3.13-nogil"
46+
- "3.14-nogil"
47+
upload_artifacts_to_gcs:
48+
description: "Should the artifacts be uploaded to a GCS bucket?"
49+
default: true
50+
type: boolean
51+
gcs_upload_uri:
52+
description: "GCS location prefix to where the artifacts should be uploaded"
53+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
54+
type: string
55+
clone_main_xla:
56+
description: "Should latest XLA be used?"
57+
type: choice
58+
default: "0"
59+
options:
60+
- "1"
61+
- "0"
62+
halt-for-connection:
63+
description: 'Should this workflow run wait for a remote connection?'
64+
type: choice
65+
default: 'no'
66+
options:
67+
- 'yes'
68+
- 'no'
69+
workflow_call:
70+
inputs:
71+
runner:
72+
description: "Which runner should the workflow run on?"
73+
type: string
74+
default: "linux-x86-n2-16"
75+
artifact:
76+
description: "Which JAX artifact to build?"
77+
type: string
78+
default: "jaxlib"
79+
python:
80+
description: "Which python version should the artifact be built for?"
81+
type: string
82+
default: "3.12"
83+
upload_artifacts_to_gcs:
84+
description: "Should the artifacts be uploaded to a GCS bucket?"
85+
default: true
86+
type: boolean
87+
gcs_upload_uri:
88+
description: "GCS location prefix to where the artifacts should be uploaded"
89+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
90+
type: string
91+
clone_main_xla:
92+
description: "Should latest XLA be used?"
93+
type: string
94+
default: "0"
95+
outputs:
96+
gcs_upload_uri:
97+
description: "GCS location prefix to where the artifacts were uploaded"
98+
value: ${{ jobs.build-jax-artifact.outputs.gcs_upload_uri }}
99+
100+
permissions: {}
101+
102+
jobs:
103+
build-jax-artifact:
104+
if: ${{ startsWith(inputs.runner, 'linux-x86') && inputs.artifact == 'jax' }}
105+
uses: ./.github/workflows/build_artifacts.yml
106+
name: "Build jax artifact"
107+
with:
108+
# Note that since jax is a pure python package, the runner OS and Python values do not
109+
# matter. In addition, cloning main XLA also has no effect.
110+
runner: ${{ inputs.runner }}
111+
artifact: "jax"
112+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
113+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
114+
115+
build-jaxlib-artifact:
116+
if: ${{ inputs.artifact == 'jaxlib' }}
117+
uses: ./.github/workflows/build_artifacts.yml
118+
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
119+
# dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
120+
# values to the name and creates a separate entry for each matrix combination.
121+
name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
122+
with:
123+
runner: ${{ inputs.runner }}
124+
artifact: "jaxlib"
125+
python: ${{ inputs.python }}
126+
clone_main_xla: ${{ inputs.clone_main_xla }}
127+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
128+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
129+
130+
build-cuda-artifacts:
131+
if: ${{ startsWith(inputs.runner, 'linux') && (inputs.artifact == 'jax-cuda-plugin' || inputs.artifact == 'jax-cuda-pjrt') }}
132+
uses: ./.github/workflows/build_artifacts.yml
133+
strategy:
134+
fail-fast: false # don't cancel all jobs on failure
135+
matrix:
136+
# Python values need to match the matrix strategy in the CUDA tests job below
137+
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
138+
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
139+
with:
140+
runner: ${{ inputs.runner }}
141+
artifact: ${{ matrix.artifact }}
142+
python: ${{ inputs.python }}
143+
clone_main_xla: ${{ inputs.clone_main_xla }}
144+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
145+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}

.github/workflows/bazel_cpu_rbe.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
3535
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
3636
JAXCI_BUILD_JAXLIB: "true"
37+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: "1"
3738
# Begin Presubmit Naming Check - name modification requires internal check to be updated
3839
strategy:
3940
matrix:

.github/workflows/bazel_cpu_rbe_no_jaxlib_build.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ on:
3838
required: false
3939
default: 'false'
4040
type: string
41+
run-backend_independent-test-suite:
42+
description: "Should the backend independent test suite be run?"
43+
type: string
44+
default: "1"
4145
gcs_download_uri:
4246
description: "GCS location prefix from where the artifacts should be downloaded"
4347
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
@@ -57,6 +61,7 @@ jobs:
5761
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
5862
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
5963
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
64+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: "${{ inputs.run-backend_independent-test-suite }}"
6065

6166
name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
6267
(contains(inputs.runner, 'linux-arm64') && 'linux arm64') ||
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# CI - Bazel wheel tests (RBE)
2+
#
3+
# This workflow is triggered in the presubmit.
4+
#
5+
# It consists of the following jobs:
6+
# build-cpu-wheels:
7+
# - Builds the JAX CPU wheels for every runner/Python/artifact matrix combination. Uses bazel_build_wheels.yml.
8+
# build-gpu-wheels:
9+
# - Builds the JAX GPU (CUDA plugin and PJRT) wheels on the Linux runners. Uses bazel_build_wheels.yml.
10+
# run-bazel-test-cpu-py-import:
11+
# - Runs the Bazel CPU tests with py_import dependencies. Uses bazel_cpu_rbe_no_jaxlib_build.yml.
12+
13+
name: CI - Bazel wheel tests (RBE)
14+
15+
on:
16+
workflow_dispatch:
17+
inputs:
18+
halt-for-connection:
19+
description: 'Should this workflow run wait for a remote connection?'
20+
type: choice
21+
required: true
22+
default: 'no'
23+
options:
24+
- 'yes'
25+
- 'no'
26+
pull_request:
27+
branches:
28+
- main
29+
push:
30+
branches:
31+
- main
32+
- 'release/**'
33+
permissions: {}
34+
concurrency:
35+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
36+
# Don't cancel in-progress jobs for main/release branches. github.ref is the fully qualified ref (e.g. refs/heads/main), so compare against that form.
37+
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
38+
39+
jobs:
40+
build-cpu-wheels:
41+
uses: ./.github/workflows/bazel_build_wheels.yml
42+
name: "Build JAX CPU wheels"
43+
strategy:
44+
fail-fast: false # don't cancel all jobs on failure
45+
matrix:
46+
# Runner OS and Python values need to match the matrix strategy in the CPU tests job
47+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
48+
python: ["3.11"]
49+
artifact: ["jax", "jaxlib", "jax-cuda-plugin", "jax-cuda-pjrt"]
50+
with:
51+
runner: ${{ matrix.runner }}
52+
python: ${{ matrix.python }}
53+
artifact: ${{ matrix.artifact }}
54+
upload_artifacts_to_gcs: false
55+
56+
build-gpu-wheels:
57+
uses: ./.github/workflows/bazel_build_wheels.yml
58+
name: "Build JAX GPU wheels"
59+
strategy:
60+
fail-fast: false # don't cancel all jobs on failure
61+
matrix:
62+
# Runner OS and Python values need to match the matrix strategy in the CPU tests job
63+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
64+
python: ["3.11"]
65+
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
66+
with:
67+
runner: ${{ matrix.runner }}
68+
python: ${{ matrix.python }}
69+
artifact: ${{ matrix.artifact }}
70+
upload_artifacts_to_gcs: false
71+
72+
# build-jax-artifact:
73+
# uses: ./.github/workflows/build_artifacts.yml
74+
# name: "Build jax artifact"
75+
# with:
76+
# # Note that since jax is a pure python package, the runner OS and Python values do not
77+
# # matter. In addition, cloning main XLA also has no effect.
78+
# runner: "linux-x86-n2-16"
79+
# artifact: "jax"
80+
# upload_artifacts_to_gcs: false
81+
82+
# build-jaxlib-artifact:
83+
# uses: ./.github/workflows/build_artifacts.yml
84+
# strategy:
85+
# fail-fast: false # don't cancel all jobs on failure
86+
# matrix:
87+
# # Runner OS and Python values need to match the matrix strategy in the CPU tests job
88+
# runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
89+
# artifact: ["jaxlib"]
90+
# python: ["3.11"]
91+
# # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
92+
# # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
93+
# # values to the name and creates a separate entry for each matrix combination.
94+
# name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
95+
# with:
96+
# runner: ${{ matrix.runner }}
97+
# artifact: ${{ matrix.artifact }}
98+
# python: ${{ matrix.python }}
99+
# clone_main_xla: 1
100+
# upload_artifacts_to_gcs: false
101+
102+
run-bazel-test-cpu-py-import:
103+
uses: ./.github/workflows/bazel_cpu_rbe_no_jaxlib_build.yml
104+
strategy:
105+
fail-fast: false # don't cancel all jobs on failure
106+
matrix:
107+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
108+
python: ["3.11"]
109+
enable-x64: [0]
110+
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
111+
with:
112+
runner: ${{ matrix.runner }}
113+
python: ${{ matrix.python }}
114+
enable-x64: ${{ matrix.enable-x64 }}
115+
build_jaxlib: "wheel"
116+
run-backend_independent-test-suite: "0"

0 commit comments

Comments
 (0)