2222 - cron : " 0 */2 * * *" # Run once every 2 hours
2323 pull_request :
2424 branches :
25- - main
25+ - main
2626
2727concurrency :
2828 group : ${{ github.workflow }}-${{ github.head_ref || github.ref }}
2929 cancel-in-progress : ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
3030
3131jobs :
32- build_jaxlib_artifact :
32+ build-jaxlib-artifact :
3333 uses : ./.github/workflows/build_artifacts.yml
3434 strategy :
3535 fail-fast : false # don't cancel all jobs on failure
@@ -46,12 +46,12 @@ jobs:
4646 upload_artifacts_to_gcs : true
4747 gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
4848
49- build_cuda_artifacts :
50- uses : ./.github/workflows/build_artifacts.yml
49+ build-cuda-artifacts :
50+ uses : ./.github/workflows/build_artifacts.yml
5151 strategy :
5252 fail-fast : false # don't cancel all jobs on failure
5353 matrix :
54- # Python values need to match the matrix stategy in the GPU tests job below
54+ # Python values need to match the matrix strategy in the CUDA tests job below
5555 runner : ["linux-x86-n2-16"]
5656 artifact : ["jax-cuda-plugin", "jax-cuda-pjrt"]
5757 python : ["3.10",]
@@ -62,48 +62,54 @@ jobs:
6262 clone_main_xla : 1
6363 upload_artifacts_to_gcs : true
6464 gcs_upload_uri : ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
65-
66- run_pytest_cpu :
67- needs : build_jaxlib_artifact
68- uses : ./.github/workflows/pytest_cpu.yml
65+
66+ run-pytest-cpu :
67+ needs : build-jaxlib-artifact
68+ uses : ./.github/workflows/pytest_cpu.yml
6969 strategy :
7070 fail-fast : false # don't cancel all jobs on failure
7171 matrix :
72- # Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above
72+ # Runner OS and Python values need to match the matrix strategy in the
73+ # build-jaxlib-artifact job above
7374 runner : ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
7475 python : ["3.10",]
7576 enable-x64 : [1, 0]
7677 with :
7778 runner : ${{ matrix.runner }}
7879 python : ${{ matrix.python }}
7980 enable-x64 : ${{ matrix.enable-x64 }}
80- gcs_download_uri : ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
81- install_jax_current_commit : " 1"
82-
81+ gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
8382
84- run_pytest_gpu :
85- needs : [build_jaxlib_artifact, build_cuda_artifacts ]
86- uses : ./.github/workflows/pytest_cuda.yml
83+ run-pytest-cuda :
84+ needs : [build-jaxlib-artifact, build-cuda-artifacts ]
85+ uses : ./.github/workflows/pytest_cuda.yml
8786 strategy :
8887 fail-fast : false # don't cancel all jobs on failure
8988 matrix :
90- # Python values need to match the matrix stategy in the build artifacts job above
91- runner : ["linux-x86-g2-48-l4-4gpu",]
89+ # Python values need to match the matrix strategy in the artifact build jobs above
90+ runner : ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu" ]
9291 python : ["3.10",]
9392 cuda : ["12.3", "12.1"]
9493 enable-x64 : [1, 0]
94+ exclude :
95+ # Run only a single configuration on H100 to save resources
96+ - runner : " linux-x86-a3-8g-h100-8gpu"
97+ python : " 3.10"
98+ cuda : " 12.1"
99+ - runner : " linux-x86-a3-8g-h100-8gpu"
100+ python : " 3.10"
101+ enable-x64 : 0
95102 with :
96103 runner : ${{ matrix.runner }}
97104 python : ${{ matrix.python }}
98105 cuda : ${{ matrix.cuda }}
99106 enable-x64 : ${{ matrix.enable-x64 }}
100- # GCS upload URI is the same for both artifact build jobs
101- gcs_download_uri : ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
102- install_jax_current_commit : " 1"
107+ # GCS upload URI is the same for both artifact build jobs
108+ gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
103109
104- run_bazel_test_gpu :
105- needs : [build_jaxlib_artifact, build_cuda_artifacts ]
106- uses : ./.github/workflows/bazel_cuda_non_rbe.yml
110+ run-bazel-test-cuda :
111+ needs : [build-jaxlib-artifact, build-cuda-artifacts ]
112+ uses : ./.github/workflows/bazel_cuda_non_rbe.yml
107113 strategy :
108114 fail-fast : false # don't cancel all jobs on failure
109115 matrix :
@@ -115,4 +121,5 @@ jobs:
115121 runner : ${{ matrix.runner }}
116122 python : ${{ matrix.python }}
117123 enable-x64 : ${{ matrix.enable-x64 }}
118- gcs_download_uri : ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
124+ # GCS upload URI is the same for both artifact build jobs
125+ gcs_download_uri : ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
0 commit comments