Skip to content

Commit 11fa64f

Browse files
Add wheel build and py_import tests in JAX presubmit.
The jobs are needed to replicate postsubmit Github Workflow actions and catch the majority of the issues in the presubmit. The tests do not block PRs yet. PiperOrigin-RevId: 795222280
1 parent b6996cb commit 11fa64f

16 files changed

+555
-81
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# CI - Bazel build wheels (RBE)
2+
#
3+
# This workflow builds all jax wheels using Bazel.
4+
#
5+
# It consists of the following jobs:
6+
# build-jax-artifact:
7+
# - Builds the jax artifact from source via build.py invocation. Uses build_artifacts.yml.
8+
# build-jaxlib-artifact:
9+
# - Builds the jaxlib artifact from source via build.py invocation. Uses build_artifacts.yml.
10+
# build-cuda-artifacts:
11+
# - Builds the cuda artifacts from source via build.py invocation. Uses build_artifacts.yml.
12+
13+
name: CI - Bazel build wheels (RBE)
14+
15+
# Triggers. The workflow can be started manually (workflow_dispatch) or invoked
# from another workflow (workflow_call). List-valued inputs are JSON-encoded
# strings that the jobs decode with fromJSON().
on:
  # Allows triggering the workflow run manually.
  # NOTE(review): this trigger declares neither upload_artifacts_to_gcs nor
  # gcs_upload_uri, although every job forwards both; and halt-for-connection
  # is declared here but not forwarded by any job in this workflow. Confirm
  # both are intended.
  workflow_dispatch:
    inputs:
      runners:
        type: string
        default: '["linux-x86-n4-16"]'
        description: "Which runners should the workflow run on? Format example: '[\"linux-x86-n4-16\", \"linux-arm64-c4a-16\"]'"
      artifacts:
        type: string
        default: '["jaxlib"]'
        description: "Which JAX artifacts to build? Format example: '[\"jaxlib\", \"jax\"]'"
      python-versions:
        type: string
        default: '["3.12"]'
        description: "Which python versions should the artifact be built for? Format example: '[\"3.12\", \"3.13\"]'"
      cuda-versions:
        type: string
        default: '["12"]'
        description: "Which cuda versions should the artifact be built for? Format example: '[\"12\", \"13\"]'"
      clone_main_xla:
        type: choice
        default: '0'
        options:
          - '1'
          - '0'
        description: 'Should latest XLA be used?'
      halt-for-connection:
        type: choice
        default: 'no'
        options:
          - 'yes'
          - 'no'
        description: 'Should this workflow run wait for a remote connection?'

  # Allows other workflows to invoke this one as a reusable workflow.
  workflow_call:
    inputs:
      # NOTE(review): this default differs from the workflow_dispatch default
      # ("linux-x86-n4-16"), and the build-jax-artifact job only runs when
      # "linux-x86-n4-16" is in the list - confirm the n2 default is intended.
      runners:
        type: string
        default: '["linux-x86-n2-16"]'
        description: 'Which runners should the workflow run on?'
      artifacts:
        type: string
        default: '["jaxlib"]'
        description: 'Which JAX artifacts to build?'
      python-versions:
        type: string
        default: '["3.12"]'
        description: 'Which python versions should the artifact be built for?'
      cuda-versions:
        type: string
        default: '["12"]'
        description: 'Which cuda versions should the artifact be built for?'
      upload_artifacts_to_gcs:
        type: boolean
        default: true
        description: 'Should the artifacts be uploaded to a GCS bucket?'
      gcs_upload_uri:
        type: string
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        description: 'GCS location prefix to where the artifacts should be uploaded'
      clone_main_xla:
        type: string
        default: '0'
        description: 'Should latest XLA be used?'
    outputs:
      # Taken from the build-jax-artifact job only.
      # NOTE(review): presumably empty when that job is skipped - confirm callers cope.
      gcs_upload_uri:
        description: 'GCS location prefix to where the artifacts were uploaded'
        value: ${{ jobs.build-jax-artifact.outputs.gcs_upload_uri }}
82+
83+
permissions: {}
84+
85+
jobs:
86+
# Builds the jax artifact through the reusable build_artifacts.yml workflow.
# Runs only when "jax" is one of the requested artifacts AND "linux-x86-n4-16"
# is one of the requested runners.
build-jax-artifact:
  name: "Build jax artifact"
  if: ${{ contains(fromJSON(inputs.artifacts), 'jax') && contains(fromJSON(inputs.runners), 'linux-x86-n4-16') }}
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
  with:
    # Note that since jax is a pure python package, the runner OS and Python values do not
    # matter. In addition, cloning main XLA also has no effect.
    runner: "linux-x86-n4-16"
    python: "3.12"
    artifact: "jax"
    # Quoted so the value is a string, matching the "0"/"1" string values that
    # clone_main_xla takes everywhere else in this workflow.
    clone_main_xla: "0"
    upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
    gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
101+
102+
# Builds jaxlib through the reusable build_artifacts.yml workflow, one matrix
# entry per (runner, python) pair. Runs only when "jaxlib" is requested.
build-jaxlib-artifact:
  # Note: For reasons unknown, Github actions groups jobs with the same top-level name in the
  # dashboard only if we use an expression in the "name" field. Otherwise, it appends the matrix
  # values to the name and creates a separate entry for each matrix combination.
  name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
  if: ${{ contains(fromJSON(inputs.artifacts), 'jaxlib') }}
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    # Don't cancel all jobs on failure.
    fail-fast: false
    matrix:
      runner: ${{ fromJSON(inputs.runners) }}
      python: ${{ fromJSON(inputs.python-versions) }}
      # The nogil Python versions are not built on the Windows runner.
      exclude:
        - runner: "windows-x86-n2-16"
          python: "3.13-nogil"
        - runner: "windows-x86-n2-16"
          python: "3.14-nogil"
  with:
    artifact: "jaxlib"
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: ${{ inputs.clone_main_xla }}
    upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
    gcs_upload_uri: ${{ inputs.gcs_upload_uri }}
126+
127+
# Builds the CUDA artifacts (jax-cuda-plugin / jax-cuda-pjrt) through the
# reusable build_artifacts.yml workflow. Runs only when at least one of the two
# CUDA artifacts is requested.
build-cuda-artifacts:
  # An expression is used in "name" so that the dashboard groups the matrix
  # entries under a single top-level name (same approach as the jaxlib job).
  name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
  if: ${{ (contains(fromJSON(inputs.artifacts), 'jax-cuda-plugin') || contains(fromJSON(inputs.artifacts), 'jax-cuda-pjrt')) }}
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      runner: ${{ fromJSON(inputs.runners) }}
      python: ${{ fromJSON(inputs.python-versions) }}
      artifact: ${{ fromJSON(inputs.artifacts) }}
      cuda-version: ${{ fromJSON(inputs.cuda-versions) }}
      exclude:
        # CUDA artifacts are not built on the Windows runner.
        - runner: "windows-x86-n2-16"
        # inputs.artifacts may also list the non-CUDA artifacts, which have
        # their own jobs; drop them here so they are not rebuilt once per CUDA
        # version.
        - artifact: "jax"
        - artifact: "jaxlib"
  with:
    runner: ${{ matrix.runner }}
    artifact: ${{ matrix.artifact }}
    python: ${{ matrix.python }}
    cuda-version: ${{ matrix.cuda-version }}
    clone_main_xla: ${{ inputs.clone_main_xla }}
    upload_artifacts_to_gcs: ${{ inputs.upload_artifacts_to_gcs }}
    gcs_upload_uri: ${{ inputs.gcs_upload_uri }}

.github/workflows/bazel_cpu.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ on:
4242
description: 'Should jax be built from source?'
4343
required: true
4444
type: string
45+
run-backend_independent-test-suite:
46+
description: 'Should backend_independent test suite be run?'
47+
required: true
48+
type: string
49+
clone_main_xla:
50+
description: "Should latest XLA be used?"
51+
required: true
52+
type: string
53+
default: "0"
4554
gcs_download_uri:
4655
description: "GCS location prefix from where the artifacts should be downloaded"
4756
default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
@@ -62,6 +71,8 @@ jobs:
6271
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
6372
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
6473
JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
74+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: ${{ inputs.run-backend_independent-test-suite }}
75+
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
6576

6677
# Begin Presubmit Naming Check - name modification requires internal check to be updated
6778
name: "Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"

.github/workflows/bazel_cpu_presubmit.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,5 @@ jobs:
5353
enable-x64: ${{ matrix.enable-x64 }}
5454
build_jaxlib: 'true'
5555
build_jax: 'true'
56+
run-backend_independent-test-suite: 1
57+
clone_main_xla: 1

.github/workflows/bazel_cuda.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ on:
7575
required: false
7676
default: 'false'
7777
type: string
78+
run-backend_independent-test-suite:
79+
description: 'Should backend_independent test suite be run?'
80+
required: true
81+
type: string
82+
clone_main_xla:
83+
description: "Should latest XLA be used?"
84+
required: true
85+
type: string
86+
default: "0"
7887
halt-for-connection:
7988
description: 'Should this workflow run wait for a remote connection?'
8089
type: string
@@ -96,6 +105,8 @@ jobs:
96105
JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE: ${{ inputs.write_to_bazel_remote_cache }}
97106
JAXCI_BUILD_JAX: ${{ inputs.build_jax }}
98107
JAXCI_BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
108+
JAXCI_RUN_BACKEND_INDEPENDENT_TESTS: ${{ inputs.run-backend_independent-test-suite }}
109+
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
99110
# Begin Presubmit Naming Check - name modification requires internal check to be updated
100111
name: "jaxlib=${{ inputs.jaxlib-version }}, CUDA=${{ inputs.cuda-version }}, Python=${{ inputs.python }}, x64=${{ inputs.enable-x64 }}, build_jax=${{ inputs.build_jax }}, build_jaxlib=${{ inputs.build_jaxlib }}"
101112
# End Presubmit Naming Check github-cuda-presubmits

.github/workflows/bazel_cuda_presubmit.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,5 @@ jobs:
6868
build_jaxlib: ${{ (matrix.jaxlib-version == 'head' && true) || false }}
6969
build_jax: 'true'
7070
run_multiaccelerator_tests: 'false'
71+
run-backend_independent-test-suite: 1
72+
clone_main_xla: 1
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# CI - Bazel py_import tests (RBE)
2+
#
3+
# This workflow is triggered only by other workflows.
4+
#
5+
# It consists of the following jobs:
6+
# run-bazel-test-cpu-py-import:
7+
# - Runs the Bazel CPU tests with py_import dependencies. Uses bazel_cpu_rbe_no_jaxlib_build.yml.
8+
# run-bazel-test-cuda-py-import:
9+
# - Currently disabled (the job is commented out in this file). When enabled, runs the Bazel CUDA tests with py_import dependencies. Uses bazel_cuda_non_rbe.yml.
10+
name: CI - Bazel py_import tests (RBE)
11+
12+
# This workflow can only be invoked from another workflow (workflow_call).
# List-valued inputs are JSON-encoded strings decoded with fromJSON().
on:
  workflow_call:
    inputs:
      # runners:
      #   description: "Which runners should the workflow run on?"
      #   type: string
      #   default: '["linux-x86-n2-16"]'
      python-versions:
        type: string
        default: '["3.12"]'
        description: 'Which python versions should the artifact be built for?'
      # cuda-versions and jaxlib-versions are referenced only by the
      # commented-out CUDA job in this file.
      cuda-versions:
        type: string
        default: '["12"]'
        description: 'Which cuda versions should the artifact be built for?'
      enable-x64:
        type: string
        default: '["0"]'
        description: 'Should x64 mode be enabled?'
      jaxlib-versions:
        type: string
        default: '["head"]'
        description: 'Which jaxlib versions to test? (head/pypi_latest)'
      clone_main_xla:
        type: string
        default: '0'
        description: 'Should latest XLA be used?'
      run-backend_independent-test-suite:
        type: string
        default: '1'
        description: 'Should backend independent test suite be run?'
43+
44+
permissions: {}
45+
46+
jobs:
47+
# Runs the Bazel CPU tests via bazel_cpu_rbe_no_jaxlib_build.yml with
# build_jaxlib set to "wheel", once per (runner, python, enable-x64) combination.
run-bazel-test-cpu-py-import:
  name: "Bazel CPU tests with build_jaxlib=wheel (${{ matrix.runner }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x64 }})"
  uses: ./.github/workflows/bazel_cpu_rbe_no_jaxlib_build.yml
  strategy:
    # Don't cancel all jobs on failure.
    fail-fast: false
    matrix:
      # The runner list is fixed here; the python and x64 axes come from the inputs.
      runner:
        - "linux-x86-n4-16"
        - "linux-arm64-c4a-16"
        - "windows-x86-n2-16"
      python: ${{ fromJSON(inputs.python-versions) }}
      enable-x64: ${{ fromJSON(inputs.enable-x64) }}
  with:
    build_jaxlib: "wheel"
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    clone_main_xla: ${{ inputs.clone_main_xla }}
    run-backend_independent-test-suite: ${{ inputs.run-backend_independent-test-suite }}
63+
64+
# run-bazel-test-cuda-py-import:
65+
# uses: ./.github/workflows/bazel_cuda_non_rbe.yml
66+
# strategy:
67+
# fail-fast: false # don't cancel all jobs on failure
68+
# matrix:
69+
# runner: ["linux-x86-g2-48-l4-4gpu"]
70+
# python: ${{ fromJSON(inputs.python-versions) }}
71+
# enable-x64: ${{ fromJSON(inputs.enable-x64) }}
72+
# cuda-version: ${{ fromJSON(inputs.cuda-versions) }}
73+
# jaxlib-version: ${{ fromJSON(inputs.jaxlib-versions) }}
74+
# name: "Bazel CUDA Non-RBE with build_jaxlib=wheel (${{ matrix.runner }}, CUDA=${{ matrix.cuda-version }}, jaxlib=${{ matrix.jaxlib-version }}, Python ${{ matrix.python }}, x64=${{ matrix.enable-x64 }})"
75+
# with:
76+
# runner: ${{ matrix.runner }}
77+
# python: ${{ matrix.python }}
78+
# cuda-version: ${{ matrix.cuda-version }}
79+
# enable-x64: ${{ matrix.enable-x64 }}
80+
# build_jaxlib: "wheel"
81+
# jaxlib-version: ${{ matrix.jaxlib-version }}
82+
# run-backend_independent-test-suite: ${{ inputs.run-backend_independent-test-suite }}
83+
# clone_main_xla: ${{ inputs.clone_main_xla }}

.github/workflows/build_artifacts.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ jobs:
125125
with:
126126
persist-credentials: false
127127
- name: Enable RBE if building on Linux x86, Windows x86, or Linux Arm64
128-
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
128+
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') || contains(inputs.runner, 'linux-arm64')
129129
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
130130
- name: Set bazel output base on Windows runner machine
131131
if: contains(inputs.runner, 'windows-x86')

0 commit comments

Comments
 (0)