Skip to content

Commit 2014eb0

Browse files
committed
fix syntax issue
1 parent 9fa26dd commit 2014eb0

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

.github/workflows/wheel_tests_continuous.yml

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ on:
2222
- cron: "0 */2 * * *" # Run once every 2 hours
2323
pull_request:
2424
branches:
25-
- main
25+
- main
2626

2727
concurrency:
2828
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
2929
cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}
3030

3131
jobs:
32-
build_jaxlib_artifact:
32+
build-jaxlib-artifact:
3333
uses: ./.github/workflows/build_artifacts.yml
3434
strategy:
3535
fail-fast: false # don't cancel all jobs on failure
@@ -46,12 +46,12 @@ jobs:
4646
upload_artifacts_to_gcs: true
4747
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
4848

49-
build_cuda_artifacts:
50-
uses: ./.github/workflows/build_artifacts.yml
49+
build-cuda-artifacts:
50+
uses: ./.github/workflows/build_artifacts.yml
5151
strategy:
5252
fail-fast: false # don't cancel all jobs on failure
5353
matrix:
54-
# Python values need to match the matrix stategy in the GPU tests job below
54+
# Python values need to match the matrix stategy in the CUDA tests job below
5555
runner: ["linux-x86-n2-16"]
5656
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
5757
python: ["3.10",]
@@ -62,48 +62,54 @@ jobs:
6262
clone_main_xla: 1
6363
upload_artifacts_to_gcs: true
6464
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
65-
66-
run_pytest_cpu:
67-
needs: build_jaxlib_artifact
68-
uses: ./.github/workflows/pytest_cpu.yml
65+
66+
run-pytest-cpu:
67+
needs: build-jaxlib-artifact
68+
uses: ./.github/workflows/pytest_cpu.yml
6969
strategy:
7070
fail-fast: false # don't cancel all jobs on failure
7171
matrix:
72-
# Runner OS and Python values need to match the matrix stategy in the build_jaxlib_artifact job above
72+
# Runner OS and Python values need to match the matrix stategy in the
73+
# build_jaxlib_artifact job above
7374
runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
7475
python: ["3.10",]
7576
enable-x64: [1, 0]
7677
with:
7778
runner: ${{ matrix.runner }}
7879
python: ${{ matrix.python }}
7980
enable-x64: ${{ matrix.enable-x64 }}
80-
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
81-
install_jax_current_commit: "1"
82-
81+
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
8382

84-
run_pytest_gpu:
85-
needs: [build_jaxlib_artifact, build_cuda_artifacts]
86-
uses: ./.github/workflows/pytest_cuda.yml
83+
run-pytest-cuda:
84+
needs: [build-jaxlib-artifact, build-cuda-artifacts]
85+
uses: ./.github/workflows/pytest_cuda.yml
8786
strategy:
8887
fail-fast: false # don't cancel all jobs on failure
8988
matrix:
90-
# Python values need to match the matrix stategy in the build artifacts job above
91-
runner: ["linux-x86-g2-48-l4-4gpu",]
89+
# Python values need to match the matrix stategy in the artifact build jobs above
90+
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
9291
python: ["3.10",]
9392
cuda: ["12.3", "12.1"]
9493
enable-x64: [1, 0]
94+
exclude:
95+
# Run only a single configuration on H100 to save resources
96+
- runner: "linux-x86-a3-8g-h100-8gpu"
97+
python: "3.10"
98+
cuda: "12.1"
99+
- runner: "linux-x86-a3-8g-h100-8gpu"
100+
python: "3.10"
101+
enable-x64: 0
95102
with:
96103
runner: ${{ matrix.runner }}
97104
python: ${{ matrix.python }}
98105
cuda: ${{ matrix.cuda }}
99106
enable-x64: ${{ matrix.enable-x64 }}
100-
# GCS upload URI is the same for both artifact build jobs
101-
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
102-
install_jax_current_commit: "1"
107+
# GCS upload URI is the same for both artifact build jobs
108+
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
103109

104-
run_bazel_test_gpu:
105-
needs: [build_jaxlib_artifact, build_cuda_artifacts]
106-
uses: ./.github/workflows/bazel_cuda_non_rbe.yml
110+
run-bazel-test-cuda:
111+
needs: [build-jaxlib-artifact, build-cuda-artifacts]
112+
uses: ./.github/workflows/bazel_cuda_non_rbe.yml
107113
strategy:
108114
fail-fast: false # don't cancel all jobs on failure
109115
matrix:
@@ -115,4 +121,5 @@ jobs:
115121
runner: ${{ matrix.runner }}
116122
python: ${{ matrix.python }}
117123
enable-x64: ${{ matrix.enable-x64 }}
118-
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
124+
# GCS upload URI is the same for both artifact build jobs
125+
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

0 commit comments

Comments
 (0)