Skip to content

Commit 1b9a5fa

Browse files
Add CPU wheel build and tests in JAX presubmit.
The jobs are needed to replicate postsubmit Github Workflow actions and catch the majority of the issues in the presubmit. The tests do not block PRs yet. PiperOrigin-RevId: 795222280
1 parent 357948b commit 1b9a5fa

File tree

9 files changed

+353
-60
lines changed

9 files changed

+353
-60
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# CI - Bazel build wheels (RBE)
2+
#
3+
# This workflow builds all jax wheels using Bazel.
4+
#
5+
# It consists of the following jobs:
6+
# build-jax-artifact:
7+
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
8+
# build-jaxlib-artifact:
9+
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
10+
# build-cuda-artifacts:
11+
# - Builds the cuda artifacts from source via build.py invocation. Uses build_artifacts.yml.
12+
13+
name: CI - Bazel build wheels (RBE)
14+
15+
on:
16+
workflow_dispatch: # allows triggering the workflow run manually
17+
inputs:
18+
runner:
19+
description: "Which runner should the workflow run on?"
20+
type: choice
21+
default: "linux-x86-n2-16"
22+
options:
23+
- "linux-x86-n2-16"
24+
- "linux-arm64-t2a-48"
25+
- "linux-arm64-c4a-16"
26+
- "windows-x86-n2-16"
27+
python:
28+
description: "Which python version should the artifact be built for?"
29+
type: choice
30+
default: "3.12"
31+
options:
32+
- "3.11"
33+
- "3.12"
34+
- "3.13"
35+
- "3.14"
36+
- "3.13-nogil"
37+
- "3.14-nogil"
38+
upload_artifacts_to_gcs:
39+
description: "Should the artifacts be uploaded to a GCS bucket?"
40+
default: true
41+
type: boolean
42+
gcs_upload_uri:
43+
description: "GCS location prefix to where the artifacts should be uploaded"
44+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
45+
type: string
46+
clone_main_xla:
47+
description: "Should latest XLA be used?"
48+
type: choice
49+
default: "0"
50+
options:
51+
- "1"
52+
- "0"
53+
halt-for-connection:
54+
description: 'Should this workflow run wait for a remote connection?'
55+
type: choice
56+
default: 'no'
57+
options:
58+
- 'yes'
59+
- 'no'
60+
workflow_call:
61+
inputs:
62+
runner:
63+
description: "Which runner should the workflow run on?"
64+
type: string
65+
options:
66+
- "linux-x86-n2-16"
67+
- "linux-arm64-t2a-48"
68+
- "linux-arm64-c4a-16"
69+
- "windows-x86-n2-16"
70+
python:
71+
description: "Which python version should the artifact be built for?"
72+
type: string
73+
default: "3.12"
74+
upload_artifacts_to_gcs:
75+
description: "Should the artifacts be uploaded to a GCS bucket?"
76+
default: true
77+
type: boolean
78+
gcs_upload_uri:
79+
description: "GCS location prefix to where the artifacts should be uploaded"
80+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
81+
type: string
82+
clone_main_xla:
83+
description: "Should latest XLA be used?"
84+
type: string
85+
default: "0"
86+
87+
permissions: {}
88+
89+
jobs:
90+
build-jax-artifact:
91+
if: ${{ startsWith(inputs.runner, 'linux-x86')}}
92+
uses: ./.github/workflows/build_artifacts.yml
93+
name: "Build jax artifact"
94+
with:
95+
# Note that since jax is a pure python package, the runner OS and Python values do not
96+
# matter. In addition, cloning main XLA also has no effect.
97+
runner: ${{ inputs.runner }}
98+
artifact: "jax"
99+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
100+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
101+
102+
build-jaxlib-artifact:
103+
uses: ./.github/workflows/build_artifacts.yml
104+
# Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
105+
# dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
106+
# values to the name and creates a separate entry for each matrix combination.
107+
name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
108+
with:
109+
runner: ${{ inputs.runner }}
110+
artifact: "jaxlib"
111+
python: ${{ inputs.python }}
112+
clone_main_xla: ${{ inputs.clone_main_xla }}
113+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
114+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
115+
116+
build-cuda-artifacts:
117+
if: ${{ startsWith(inputs.runner, 'linux')}}
118+
uses: ./.github/workflows/build_artifacts.yml
119+
strategy:
120+
fail-fast: false # don't cancel all jobs on failure
121+
matrix:
122+
# Python values need to match the matrix stategy in the CUDA tests job below
123+
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
124+
name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
125+
with:
126+
runner: ${{ inputs.runner }}
127+
artifact: ${{ matrix.artifact }}
128+
python: ${{ inputs.python }}
129+
clone_main_xla: ${{ inputs.clone_main_xla }}
130+
upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
131+
gcs_upload_uri: ${{ inputs.gcs_upload_uri }}

.github/workflows/bazel_cpu_rbe.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
3535
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
3636
JAXCI_BUILD_JAXLIB: "true"
37+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: "1"
3738
# Begin Presubmit Naming Check - name modification requires internal check to be updated
3839
strategy:
3940
matrix:

.github/workflows/bazel_cpu_rbe_no_jaxlib_build.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ on:
3838
required: false
3939
default: 'false'
4040
type: string
41+
run-backend_independent-test-suite:
42+
description: "Should the backend independent test suite be run?"
43+
type: string
44+
default: "1"
4145
gcs_download_uri:
4246
description: "GCS location prefix from where the artifacts should be downloaded"
4347
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
@@ -57,6 +61,7 @@ jobs:
5761
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
5862
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
5963
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
64+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: "${{ inputs.run-backend_independent-test-suite }}"
6065

6166
name: "${{ (contains(inputs.runner, 'linux-x86') && 'linux x86') ||
6267
(contains(inputs.runner, 'linux-arm64') && 'linux arm64') ||
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# CI - Bazel wheel tests (RBE)
2+
#
3+
# This workflow is triggered in the presubmit.
4+
#
5+
# It consists of the following jobs:
6+
# build-jax-artifact:
7+
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
8+
# build-jaxlib-artifact:
9+
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
10+
# run-bazel-test-cpu-py-import:
11+
# - Runs the Bazel CPU tests with py_import dependencies. Uses bazel_cpu_rbe_no_jaxlib_build.yml.
12+
13+
name: CI - Bazel wheel tests (RBE)
14+
15+
on:
16+
workflow_dispatch:
17+
inputs:
18+
halt-for-connection:
19+
description: 'Should this workflow run wait for a remote connection?'
20+
type: choice
21+
required: true
22+
default: 'no'
23+
options:
24+
- 'yes'
25+
- 'no'
26+
pull_request:
27+
branches:
28+
- main
29+
push:
30+
branches:
31+
- main
32+
- 'release/**'
33+
permissions: {}
34+
concurrency:
35+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
36+
# Don't cancel in-progress jobs for main/release branches.
37+
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
38+
39+
jobs:
40+
build-wheels:
41+
uses: ./.github/workflows/bazel_build_wheels.yml
42+
name: "Build jax wheels"
43+
strategy:
44+
fail-fast: false # don't cancel all jobs on failure
45+
matrix:
46+
# Runner OS and Python values need to match the matrix stategy in the CPU tests job
47+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
48+
python: ["3.11"]
49+
with:
50+
runner: ${{ matrix.runner }}
51+
python: ${{ matrix.python }}
52+
upload_artifacts_to_gcs: false
53+
54+
# build-jax-artifact:
55+
# uses: ./.github/workflows/build_artifacts.yml
56+
# name: "Build jax artifact"
57+
# with:
58+
# # Note that since jax is a pure python package, the runner OS and Python values do not
59+
# # matter. In addition, cloning main XLA also has no effect.
60+
# runner: "linux-x86-n2-16"
61+
# artifact: "jax"
62+
# upload_artifacts_to_gcs: false
63+
64+
# build-jaxlib-artifact:
65+
# uses: ./.github/workflows/build_artifacts.yml
66+
# strategy:
67+
# fail-fast: false # don't cancel all jobs on failure
68+
# matrix:
69+
# # Runner OS and Python values need to match the matrix stategy in the CPU tests job
70+
# runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
71+
# artifact: ["jaxlib"]
72+
# python: ["3.11"]
73+
# # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
74+
# # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
75+
# # values to the name and creates a separate entry for each matrix combination.
76+
# name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
77+
# with:
78+
# runner: ${{ matrix.runner }}
79+
# artifact: ${{ matrix.artifact }}
80+
# python: ${{ matrix.python }}
81+
# clone_main_xla: 1
82+
# upload_artifacts_to_gcs: false
83+
84+
run-bazel-test-cpu-py-import:
85+
uses: ./.github/workflows/bazel_cpu_rbe_no_jaxlib_build.yml
86+
strategy:
87+
fail-fast: false # don't cancel all jobs on failure
88+
matrix:
89+
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16", "windows-x86-n2-16"]
90+
python: ["3.11"]
91+
enable-x64: [0]
92+
name: "Bazel CPU tests with ${{ format('{0}', 'build_jaxlib=wheel') }}"
93+
with:
94+
runner: ${{ matrix.runner }}
95+
python: ${{ matrix.python }}
96+
enable-x64: ${{ matrix.enable-x64 }}
97+
build_jaxlib: "wheel"
98+
run-backend_independent-test-suite: "0"

0 commit comments

Comments
 (0)