Skip to content

Commit 8fb888c

Browse files
committed
enabel pytest gpu workflow
1 parent c65cbd1 commit 8fb888c

File tree

1 file changed

+40
-30
lines changed

1 file changed

+40
-30
lines changed

.github/workflows/pytest_gpu.yml

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ jobs:
2121
matrix:
2222
python: ["3.10"]
2323

24-
runs-on: "linux-x86-n2-16"
25-
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
24+
runs-on: "linux-x86-g2-48-l4-4gpu"
25+
container: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
2626

2727
name: "Pytest GPU (Build wheels on CUDA 12.3)"
2828
env:
@@ -42,34 +42,6 @@ jobs:
4242
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
4343
- name: Build jax-cuda-pjrt
4444
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
45-
46-
run_tests:
47-
needs: build_artifacts
48-
strategy:
49-
matrix:
50-
test_env: [
51-
{cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu",
52-
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
53-
{cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu",
54-
image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
55-
]
56-
python: ["3.10"]
57-
58-
runs-on: ${{ matrix.test_env.runner }}
59-
container:
60-
image: ${{ matrix.test_env.image }}
61-
62-
name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})"
63-
env:
64-
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
65-
66-
steps:
67-
- uses: actions/checkout@v3
68-
# Halt for testing
69-
- name: Wait For Connection
70-
uses: google-ml-infra/actions/ci_connection@main
71-
with:
72-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
7345
- name: Install pytest
7446
env:
7547
JAXCI_PYTHON: python${{ matrix.python }}
@@ -80,3 +52,41 @@ jobs:
8052
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
8153
- name: Run Pytest GPU tests
8254
run: ./ci/run_pytest_gpu.sh
55+
56+
# run_tests:
57+
# needs: build_artifacts
58+
# strategy:
59+
# matrix:
60+
# test_env: [
61+
# {cuda_version: "12.3", runner: "linux-x86-g2-48-l4-4gpu",
62+
# image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
63+
# {cuda_version: "12.1", runner: "linux-x86-g2-48-l4-4gpu",
64+
# image: "gcr.io/tensorflow-testing/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"},
65+
# ]
66+
# python: ["3.10"]
67+
68+
# runs-on: ${{ matrix.test_env.runner }}
69+
# container:
70+
# image: ${{ matrix.test_env.image }}
71+
72+
# name: "Pytest GPU (Test on CUDA ${{ matrix.test_env.cuda_version }})"
73+
# env:
74+
# JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
75+
76+
# steps:
77+
# - uses: actions/checkout@v3
78+
# # Halt for testing
79+
# - name: Wait For Connection
80+
# uses: google-ml-infra/actions/ci_connection@main
81+
# with:
82+
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
83+
# - name: Install pytest
84+
# env:
85+
# JAXCI_PYTHON: python${{ matrix.python }}
86+
# run: $JAXCI_PYTHON -m pip install pytest
87+
# - name: Install dependencies
88+
# env:
89+
# JAXCI_PYTHON: python${{ matrix.python }}
90+
# run: $JAXCI_PYTHON -m pip install -r build/requirements.in
91+
# - name: Run Pytest GPU tests
92+
# run: ./ci/run_pytest_gpu.sh

0 commit comments

Comments
 (0)