File tree Expand file tree Collapse file tree 1 file changed +63
-0
lines changed Expand file tree Collapse file tree 1 file changed +63
-0
lines changed Original file line number Diff line number Diff line change 1+ name : Benchmarks
2+
3+ on :
4+ # pull_request:
5+ # branches:
6+ # - main
7+ workflow_dispatch :
8+ inputs :
9+ halt-for-connection :
10+ description : ' Should this workflow run wait for a remote connection?'
11+ type : choice
12+ required : true
13+ default : ' no'
14+ options :
15+ - ' yes'
16+ - ' no'
17+
18+ jobs :
19+ build :
20+ strategy :
21+ matrix :
22+ runner : ["linux-x86-g2-48-l4-4gpu"]
23+
24+ runs-on : ${{ matrix.runner }}
25+ container :
26+ image : " gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
27+
28+ env :
29+ JAXCI_HERMETIC_PYTHON_VERSION : 3.11
30+
31+ steps :
32+ - uses : actions/checkout@v3
33+ # Halt for testing
34+ - name : Wait For Connection
35+ uses : google-ml-infra/actions/ci_connection@main
36+ with :
37+ halt-dispatch-input : ${{ inputs.halt-for-connection }}
38+ - name : Build jaxlib
39+ env :
40+ JAXCI_CLONE_MAIN_XLA : 1
41+ run : ./ci/build_artifacts.sh "jaxlib"
42+ - name : Build jax-cuda-plugin
43+ env :
44+ JAXCI_CLONE_MAIN_XLA : 1
45+ run : ./ci/build_artifacts.sh "jax-cuda-plugin"
46+ - name : Build jax-cuda-pjrt
47+ env :
48+ JAXCI_CLONE_MAIN_XLA : 1
49+ run : ./ci/build_artifacts.sh "jax-cuda-pjrt"
50+ - name : Run Bazel GPU tests locally
51+ run : ./ci/run_bazel_test_gpu_non_rbe.sh
52+ - name : Install dependencies
53+ run : |
54+ python -m pip install --upgrade pip
55+ pip install pytest
56+ pip install absl-py
57+ # pip install -U jax
58+ # pip install -U "jax[cuda12]"
59+ pip install google-benchmark
60+ - name : Run Multiprocess GPU Test
61+ run : |
62+ python -m pytest tests/multiprocess_gpu_test.py
63+
You can’t perform that action at this time.
0 commit comments