Skip to content

Commit e43a85e

Browse files
committed
Reuse the Build artifact workflow
1 parent 6b7720f commit e43a85e

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed
Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Run Bazel GPU tests (non RBE)
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -16,14 +16,22 @@ on:
1616
- 'no'
1717

1818
jobs:
19-
build:
20-
strategy:
21-
matrix:
22-
runner: ["linux-x86-g2-48-l4-4gpu"]
23-
24-
runs-on: ${{ matrix.runner }}
25-
container:
26-
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
19+
build_artifacts:
20+
name: "Build the jaxlib and CUDA plugins using latest XLA"
21+
uses: ./.github/workflows/build_artifacts.yml
22+
with:
23+
wheel_list: "jaxlib,jax-cuda-plugin,jax-cuda-pjrt"
24+
python_list: "3.11"
25+
platform_list: "linux_x86"
26+
clone_main_xla: 1
27+
upload_artifacts: true
28+
upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
29+
30+
run_bazel_tests:
31+
name: "Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)"
32+
needs: build_artifacts
33+
runs-on: "linux-x86-g2-48-l4-4gpu"
34+
container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
2735

2836
env:
2937
JAXCI_HERMETIC_PYTHON_VERSION: 3.11
@@ -35,17 +43,12 @@ jobs:
3543
uses: google-ml-infra/actions/ci_connection@main
3644
with:
3745
halt-dispatch-input: ${{ inputs.halt-for-connection }}
38-
- name: Build jaxlib
39-
env:
40-
JAXCI_CLONE_MAIN_XLA: 1
41-
run: ./ci/build_artifacts.sh "jaxlib"
42-
- name: Build jax-cuda-plugin
43-
env:
44-
JAXCI_CLONE_MAIN_XLA: 1
45-
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
46-
- name: Build jax-cuda-pjrt
47-
env:
48-
JAXCI_CLONE_MAIN_XLA: 1
49-
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
46+
- name: Set Platform
47+
run: |
48+
os=$(uname -s | awk '{print tolower($0)}')
49+
arch=$(uname -m)
50+
echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV
51+
- name: Download the artifacts built in the "build_artifacts" job
52+
run: mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/
5053
- name: Run Bazel GPU tests locally
5154
run: ./ci/run_bazel_test_gpu_non_rbe.sh

0 commit comments

Comments
 (0)