Skip to content

Commit 53f20d9

Browse files
authored
Merge pull request #17 from google-ml-infra/srnitin/task-jax-ci-rework
Add new CI scripts as part of JAX CI Rework
2 parents cab2ff6 + 165bb00 commit 53f20d9

26 files changed

+1927
-501
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
name: Run Bazel CPU tests (RBE)
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
18+
jobs:
19+
run_bazel_rbe_cpu_tests:
20+
continue-on-error: true
21+
defaults:
22+
run:
23+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
24+
shell: bash
25+
strategy:
26+
matrix:
27+
runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
28+
29+
runs-on: ${{ matrix.runner }}
30+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12') ||
31+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
32+
(contains(matrix.runner, 'windows-x86') && null) }}
33+
34+
env:
35+
JAXCI_CLONE_MAIN_XLA: 1
36+
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
37+
38+
steps:
39+
- uses: actions/checkout@v3
40+
# Halt for testing
41+
- name: Wait For Connection
42+
uses: google-ml-infra/actions/ci_connection@main
43+
with:
44+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
45+
- name: Run Bazel CPU Tests
46+
run: ./ci/run_bazel_test_cpu_rbe.sh
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
name: Run Bazel GPU tests (non RBE)
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
18+
jobs:
19+
build:
20+
strategy:
21+
matrix:
22+
runner: ["linux-x86-g2-48-l4-4gpu"]
23+
24+
runs-on: ${{ matrix.runner }}
25+
container:
26+
image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
27+
28+
env:
29+
JAXCI_HERMETIC_PYTHON_VERSION: 3.11
30+
31+
steps:
32+
- uses: actions/checkout@v3
33+
# Halt for testing
34+
- name: Wait For Connection
35+
uses: google-ml-infra/actions/ci_connection@main
36+
with:
37+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
38+
- name: Build jaxlib
39+
env:
40+
JAXCI_CLONE_MAIN_XLA: 1
41+
run: ./ci/build_artifacts.sh "jaxlib"
42+
- name: Build jax-cuda-plugin
43+
env:
44+
JAXCI_CLONE_MAIN_XLA: 1
45+
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
46+
- name: Build jax-cuda-pjrt
47+
env:
48+
JAXCI_CLONE_MAIN_XLA: 1
49+
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
50+
- name: Run Bazel GPU tests locally
51+
run: ./ci/run_bazel_test_gpu_non_rbe.sh
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Run Bazel GPU tests (RBE)
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
18+
jobs:
19+
build:
20+
strategy:
21+
matrix:
22+
runner: ["linux-x86-n2-16"]
23+
24+
runs-on: ${{ matrix.runner }}
25+
container:
26+
image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
27+
28+
env:
29+
JAXCI_CLONE_MAIN_XLA: 1
30+
JAXCI_HERMETIC_PYTHON_VERSION: 3.12
31+
32+
steps:
33+
- uses: actions/checkout@v3
34+
# Halt for testing
35+
- name: Wait For Connection
36+
uses: google-ml-infra/actions/ci_connection@main
37+
with:
38+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
39+
- name: Run Bazel GPU tests using RBE
40+
run: ./ci/run_bazel_test_gpu_rbe.sh
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
name: Build JAX Artifacts
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
workflow_call:
18+
19+
jobs:
20+
build:
21+
continue-on-error: true
22+
defaults:
23+
run:
24+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
25+
shell: bash
26+
strategy:
27+
matrix:
28+
runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
29+
artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
30+
python: ["3.10", "3.11", "3.12"]
31+
# jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
32+
# Python version.
33+
exclude:
34+
# Pure Python packages do not need to be built for each Python version.
35+
- artifact: "jax-cuda-pjrt"
36+
python: "3.10"
37+
- artifact: "jax-cuda-pjrt"
38+
python: "3.11"
39+
- artifact: "jax"
40+
python: "3.10"
41+
- artifact: "jax"
42+
python: "3.11"
43+
# jax is a pure Python package so it does not need to be built on multiple platforms.
44+
- artifact: "jax"
45+
runner: "windows-x86-n2-64"
46+
- artifact: "jax"
47+
runner: "linux-arm64-t2a-16"
48+
# jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows.
49+
- artifact: "jax-cuda-plugin"
50+
runner: "windows-x86-n2-64"
51+
- artifact: "jax-cuda-pjrt"
52+
runner: "windows-x86-n2-64"
53+
54+
runs-on: ${{ matrix.runner }}
55+
56+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
57+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
58+
(contains(matrix.runner, 'windows-x86') && null) }}
59+
60+
env:
61+
# Do not run Docker container for Linux runners. Linux runners already run in a Docker container.
62+
JAXCI_RUN_DOCKER_CONTAINER: 0
63+
64+
steps:
65+
- uses: actions/checkout@v3
66+
# Halt for testing
67+
- name: Wait For Connection
68+
uses: google-ml-infra/actions/ci_connection@main
69+
with:
70+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
71+
- name: Build ${{ matrix.artifact }}
72+
env:
73+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
74+
run: ./ci/build_artifacts.sh "${{ matrix.artifact }}"

.github/workflows/pytest_cpu.yml

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
name: Run Pytest CPU tests
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
- 'no'
17+
18+
jobs:
19+
build:
20+
continue-on-error: true
21+
defaults:
22+
run:
23+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
24+
shell: bash
25+
strategy:
26+
matrix:
27+
runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"]
28+
python: ["3.10"]
29+
30+
runs-on: ${{ matrix.runner }}
31+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
32+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
33+
(contains(matrix.runner, 'windows-x86') && null) }}
34+
35+
env:
36+
JAXCI_CLONE_MAIN_XLA: 1
37+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
38+
39+
steps:
40+
- uses: actions/checkout@v3
41+
# Halt for testing
42+
- name: Wait For Connection
43+
uses: google-ml-infra/actions/ci_connection@main
44+
with:
45+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
46+
- name: Build jaxlib
47+
run: ./ci/build_artifacts.sh "jaxlib"
48+
- name: Install pytest
49+
env:
50+
JAXCI_PYTHON: python${{ matrix.python }}
51+
run: $JAXCI_PYTHON -m pip install pytest
52+
- name: Install dependencies
53+
env:
54+
JAXCI_PYTHON: python${{ matrix.python }}
55+
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
56+
- name: Run Pytest CPU tests
57+
run: ./ci/run_pytest_cpu.sh
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
name: Run Pytest CPU tests (resuable workflow)
2+
3+
on:
4+
# pull_request:
5+
# branches:
6+
# - main
7+
workflow_dispatch:
8+
inputs:
9+
halt-for-connection:
10+
description: 'Should this workflow run wait for a remote connection?'
11+
type: choice
12+
required: true
13+
default: 'no'
14+
options:
15+
- 'yes'
16+
17+
jobs:
18+
build_jaxlib_artifacts:
19+
uses: ./.github/workflows/build_artifacts.yml
20+
21+
run_pytest:
22+
needs: build_jaxlib_artifacts
23+
continue-on-error: true
24+
defaults:
25+
run:
26+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
27+
shell: bash
28+
strategy:
29+
matrix:
30+
runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"]
31+
python: ["3.10"]
32+
33+
runs-on: ${{ matrix.runner }}
34+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
35+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
36+
(contains(matrix.runner, 'windows-x86') && null) }}
37+
38+
env:
39+
JAXCI_CLONE_MAIN_XLA: 1
40+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
41+
42+
steps:
43+
- uses: actions/checkout@v3
44+
# Halt for testing
45+
- name: Wait For Connection
46+
uses: google-ml-infra/actions/ci_connection@main
47+
with:
48+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
49+
- name: Install pytest
50+
env:
51+
JAXCI_PYTHON: python${{ matrix.python }}
52+
run: $JAXCI_PYTHON -m pip install pytest
53+
- name: Install dependencies
54+
env:
55+
JAXCI_PYTHON: python${{ matrix.python }}
56+
run: $JAXCI_PYTHON -m pip install -r build/requirements.in
57+
- name: Run Pytest CPU tests
58+
run: ./ci/run_pytest.sh "ci/envs/run_tests/pytest_cpu"

0 commit comments

Comments
 (0)