|
16 | 16 | - 'no' |
17 | 17 |
|
18 | 18 | jobs: |
19 | | - run_pytest_gpu: |
| 19 | + build_artifacts: |
| 20 | + strategy: |
| 21 | + matrix: |
| 22 | + python: ["3.10"] |
| 23 | + |
| 24 | + runs-on: "linux-x86-n2-16" |
| 25 | + container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" |
| 26 | + |
| 27 | + name: "Pytest GPU (Build wheels on CUDA 12.3)" |
| 28 | + env: |
| 29 | + JAXCI_CLONE_MAIN_XLA: 1 |
| 30 | + JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} |
| 31 | + |
| 32 | + steps: |
| 33 | + - uses: actions/checkout@v3 |
| 34 | + # Halt for testing |
| 35 | + - name: Wait For Connection |
| 36 | + uses: google-ml-infra/actions/ci_connection@main |
| 37 | + with: |
| 38 | + halt-dispatch-input: ${{ inputs.halt-for-connection }} |
| 39 | + - name: Build jaxlib |
| 40 | + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env" |
| 41 | + - name: Build jax-cuda-plugin |
| 42 | + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env" |
| 43 | + - name: Build jax-cuda-pjrt |
| 44 | + run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env" |
| 45 | + |
| 46 | + run_tests: |
| 47 | + needs: build_artifacts |
20 | 48 | strategy: |
21 | 49 | matrix: |
22 | 50 | test_env: [ |
|
31 | 59 | container: |
32 | 60 | image: ${{ matrix.test_env.image }} |
33 | 61 |
|
34 | | - name: "Pytest GPU (Build wheels on CUDA 12.3 and test on CUDA ${{ matrix.test_env.cuda_version }})" |
| 62 | + name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})" |
35 | 63 | env: |
36 | | - JAXCI_CLONE_MAIN_XLA: 1 |
37 | 64 | JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }} |
38 | 65 |
|
39 | 66 | steps: |
|
43 | 70 | uses: google-ml-infra/actions/ci_connection@main |
44 | 71 | with: |
45 | 72 | halt-dispatch-input: ${{ inputs.halt-for-connection }} |
46 | | - - name: Build jaxlib |
47 | | - run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jaxlib.env" |
48 | | - - name: Build jax-cuda-plugin |
49 | | - run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-plugin.env" |
50 | | - - name: Build jax-cuda-pjrt |
51 | | - run: ./ci/build_artifacts.sh "ci/envs/build_artifacts/jax-cuda-pjrt.env" |
52 | 73 | - name: Install pytest |
53 | 74 | env: |
54 | 75 | JAXCI_PYTHON: python${{ matrix.python }} |
|
0 commit comments