Skip to content

Commit 2e83e5f

Browse files
committed
split build and test jobs
1 parent 250adf0 commit 2e83e5f

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

.github/workflows/pytest_gpu.yml

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,35 @@ on:
1616
- 'no'
1717

1818
jobs:
19-
run_pytest_gpu:
19+
build_artifacts:
20+
strategy:
21+
matrix:
22+
python: ["3.10"]
23+
24+
runs-on: "linux-x86-n2-16"
25+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
26+
27+
name: "Pytest GPU (Build wheels on CUDA 12.3)"
28+
env:
29+
JAXCI_CLONE_MAIN_XLA: 1
30+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
31+
32+
steps:
33+
- uses: actions/checkout@v3
34+
# Halt for testing
35+
- name: Wait For Connection
36+
uses: google-ml-infra/actions/ci_connection@main
37+
with:
38+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
39+
- name: Build jaxlib
40+
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
41+
- name: Build jax-cuda-plugin
42+
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env"
43+
- name: Build jax-cuda-pjrt
44+
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env"
45+
46+
run_tests:
47+
needs: build_artifacts
2048
strategy:
2149
matrix:
2250
test_env: [
@@ -31,9 +59,8 @@ jobs:
3159
container:
3260
image: ${{ matrix.test_env.image }}
3361

34-
name: "Pytest GPU (Build wheels on CUDA 12.3 and test on CUDA ${{ matrix.test_env.cuda_version }})"
62+
name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})"
3563
env:
36-
JAXCI_CLONE_MAIN_XLA: 1
3764
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
3865

3966
steps:
@@ -43,12 +70,6 @@ jobs:
4370
uses: google-ml-infra/actions/ci_connection@main
4471
with:
4572
halt-dispatch-input: ${{ inputs.halt-for-connection }}
46-
- name: Build jaxlib
47-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env"
48-
- name: Build jax-cuda-plugin
49-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env"
50-
- name: Build jax-cuda-pjrt
51-
run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env"
5273
- name: Install pytest
5374
env:
5475
JAXCI_PYTHON: python${{ matrix.python }}

0 commit comments

Comments
 (0)