Skip to content

Commit 250adf0

Browse files
committed
Add values for testing on CUDA 12.1 image
1 parent d977dca commit 250adf0

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

.github/workflows/pytest_gpu.yml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,22 @@ on:
1616
- 'no'
1717

1818
jobs:
19-
build:
19+
run_pytest_gpu:
2020
strategy:
2121
matrix:
22-
runner: ["linux-x86-g2-48-l4-4gpu"]
22+
test_env: [
23+
{cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu",
24+
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
25+
{cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu",
26+
image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
27+
]
2328
python: ["3.10"]
2429

25-
runs-on: ${{ matrix.runner }}
30+
runs-on: ${{ matrix.test_env.runner }}
2631
container:
27-
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
28-
32+
image: ${{ matrix.test_env.image }}
33+
34+
name: "Pytest GPU (Build wheels on CUDA 12.3 and test on CUDA ${{ matrix.test_env.cuda_version }})"
2935
env:
3036
JAXCI_CLONE_MAIN_XLA: 1
3137
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}

0 commit comments

Comments
 (0)