Skip to content

Commit 7d1b97f

Browse files
Create benchmarks.yml
1 parent dc928ba commit 7d1b97f

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

.github/workflows/benchmarks.yml

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
name: Benchmarks
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
18+
jobs:
19+
build:
20+
strategy:
21+
matrix:
22+
runner: ["linux-x86-g2-48-l4-4gpu"]
23+
24+
runs-on: ${{ matrix.runner }}
25+
container:
26+
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
27+
28+
env:
29+
JAXCI_HERMETIC_PYTHON_VERSION: 3.11
30+
31+
steps:
32+
- uses: actions/checkout@v3
33+
# Halt for testing
34+
- name: Wait For Connection
35+
uses: google-ml-infra/actions/ci_connection@main
36+
with:
37+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
38+
- name: Build jaxlib
39+
env:
40+
JAXCI_CLONE_MAIN_XLA: 1
41+
run: ./ci/build_artifacts.sh "jaxlib"
42+
- name: Build jax-cuda-plugin
43+
env:
44+
JAXCI_CLONE_MAIN_XLA: 1
45+
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
46+
- name: Build jax-cuda-pjrt
47+
env:
48+
JAXCI_CLONE_MAIN_XLA: 1
49+
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
50+
- name: Run Bazel GPU tests locally
51+
run: ./ci/run_bazel_test_gpu_non_rbe.sh
52+
- name: Install dependencies
53+
run: |
54+
python -m pip install --upgrade pip
55+
pip install pytest
56+
pip install absl-py
57+
# pip install -U jax
58+
# pip install -U "jax[cuda12]"
59+
pip install google-benchmark
60+
- name: Run Multiprocess GPU Test
61+
run: |
62+
python -m pytest tests/multiprocess_gpu_test.py
63+

0 commit comments

Comments
 (0)