11name : Run Bazel GPU tests (non RBE)
22
33on :
4- # pull_request:
5- # branches:
6- # - main
4+ pull_request :
5+ branches :
6+ - main
77 workflow_dispatch :
88 inputs :
99 halt-for-connection :
1616 - ' no'
1717
1818jobs :
19- build :
20- strategy :
21- matrix :
22- runner : ["linux-x86-g2-48-l4-4gpu"]
23-
24- runs-on : ${{ matrix.runner }}
25- container :
26- image : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
19+ build_artifacts :
20+ name : " Build the jaxlib and CUDA plugins using latest XLA"
21+ uses : ./.github/workflows/build_artifacts.yml
22+ with :
23+ wheel_list : " jaxlib,jax-cuda-plugin,jax-cuda-pjrt"
24+ python_list : " 3.11"
25+ platform_list : " linux_x86"
26+ clone_main_xla : 1
27+ upload_artifacts : true
28+ upload_destination : ' ${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
29+
30+ run_bazel_tests :
31+ name : " Run Bazel GPU tests (single accelerator and multi-accelerator tests, non-RBE)"
32+ needs : build_artifacts
33+ runs-on : " linux-x86-g2-48-l4-4gpu"
34+ container : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
2735
2836 env :
2937 JAXCI_HERMETIC_PYTHON_VERSION : 3.11
@@ -35,17 +43,12 @@ jobs:
3543 uses : google-ml-infra/actions/ci_connection@main
3644 with :
3745 halt-dispatch-input : ${{ inputs.halt-for-connection }}
38- - name : Build jaxlib
39- env :
40- JAXCI_CLONE_MAIN_XLA : 1
41- run : ./ci/build_artifacts.sh "jaxlib"
42- - name : Build jax-cuda-plugin
43- env :
44- JAXCI_CLONE_MAIN_XLA : 1
45- run : ./ci/build_artifacts.sh "jax-cuda-plugin"
46- - name : Build jax-cuda-pjrt
47- env :
48- JAXCI_CLONE_MAIN_XLA : 1
49- run : ./ci/build_artifacts.sh "jax-cuda-pjrt"
46+ - name : Set Platform
47+ run : |
48+ os=$(uname -s | awk '{print tolower($0)}')
49+ arch=$(uname -m)
50+ echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV
51+ - name : Download the artifacts built in the "build_artifacts" job
52+ run : mkdir -p $(pwd)/dist && gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }}/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION} $(pwd)/dist/
5053 - name : Run Bazel GPU tests locally
5154 run : ./ci/run_bazel_test_gpu_non_rbe.sh
0 commit comments