Skip to content

Commit d4935bb

Browse files
committed
Sync to upstream and add new workflows for running continuous/nightly testing
1 parent 1ef1508 commit d4935bb

18 files changed

+462
-318
lines changed
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: CI - Bazel CPU tests (RBE)
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -14,18 +14,19 @@ on:
1414
options:
1515
- 'yes'
1616
- 'no'
17-
pull_request:
18-
branches:
19-
- main
2017

2118
concurrency:
2219
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
2320
cancel-in-progress: true
2421

2522
jobs:
2623
run_tests:
27-
if: github.event.repository.fork == false
24+
defaults:
25+
run:
26+
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
27+
shell: bash
2828
strategy:
29+
fail-fast: false # don't cancel all jobs on failure
2930
matrix:
3031
runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-c4a-16"]
3132
enable-x_64: [1, 0]
@@ -35,18 +36,17 @@ jobs:
3536
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
3637
(contains(matrix.runner, 'windows-x86') && null) }}
3738

38-
3939
env:
4040
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
4141
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}
4242

4343
name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"
4444

4545
steps:
46-
- uses: actions/checkout@v3
47-
- name: Wait For Connection
48-
uses: google-ml-infra/actions/ci_connection@main
49-
with:
50-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
51-
- name: Run Bazel CPU Tests with RBE
52-
run: ./ci/run_bazel_test_cpu_rbe.sh
46+
- uses: actions/checkout@v3
47+
- name: Wait For Connection
48+
uses: google-ml-infra/actions/ci_connection@main
49+
with:
50+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
51+
- name: Run Bazel CPU Tests with RBE
52+
run: ./ci/run_bazel_test_cpu_rbe.sh
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
name: CI - Bazel CUDA (non-RBE)
2+
3+
on:
4+
workflow_dispatch:
5+
inputs:
6+
halt-for-connection:
7+
description: 'Should this workflow run wait for a remote connection?'
8+
type: choice
9+
required: true
10+
default: 'no'
11+
options:
12+
- 'yes'
13+
- 'no'
14+
workflow_call:
15+
inputs:
16+
runner:
17+
description: "Which runner should the workflow run on?"
18+
type: string
19+
required: true
20+
default: "linux-x86-n2-16"
21+
python:
22+
description: "Which python version to test?"
23+
type: string
24+
required: true
25+
default: "3.12"
26+
enable-x64:
27+
description: "Should x64 mode be enabled?"
28+
type: string
29+
required: true
30+
default: "0"
31+
gcs_download_uri:
32+
description: "GCS location URI from where the artifacts should be downloaded"
33+
required: true
34+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
35+
type: string
36+
37+
38+
jobs:
39+
run-tests:
40+
runs-on: ${{ inputs.runner }}
41+
42+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-rbe:latest"
43+
44+
env:
45+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
46+
JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}"
47+
48+
name: "Bazel single accelerator and multi-accelerator GPU tests (Non RBE, ${{ inputs.runner }}, Python 3.11, x64=${{ inputs.enable-x64 }})"
49+
50+
steps:
51+
- uses: actions/checkout@v3
52+
# Halt for testing
53+
- name: Wait For Connection
54+
uses: google-ml-infra/actions/ci_connection@main
55+
with:
56+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
57+
- name: Set PLATFORM env var for use in artifact download URL
58+
run: |
59+
os=$(uname -s | awk '{print tolower($0)}')
60+
arch=$(uname -m)
61+
62+
# Get the major and minor version of Python.
63+
# E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
64+
python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
65+
66+
echo "OS=${os}" >> $GITHUB_ENV
67+
echo "ARCH=${arch}" >> $GITHUB_ENV
68+
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
69+
- name: Download the wheel artifacts from GCS
70+
run: >-
71+
mkdir -p $(pwd)/dist &&
72+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
73+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
74+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/
75+
- name: Run Bazel tests
76+
run: ./ci/run_bazel_test_gpu_non_rbe.sh

.github/workflows/bazel_gpu_rbe.yml renamed to .github/workflows/bazel_cuda_rbe.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
name: CI - Bazel GPU tests (RBE)
1+
name: CI - Bazel CUDA tests (RBE)
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -24,7 +24,6 @@ concurrency:
2424

2525
jobs:
2626
run_tests:
27-
if: github.event.repository.fork == false
2827
strategy:
2928
matrix:
3029
runner: ["linux-x86-n2-16"]

.github/workflows/bazel_gpu_non_rbe.yml

Lines changed: 0 additions & 51 deletions
This file was deleted.

.github/workflows/build_artifacts.yml

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ on:
5858
required: false
5959
default: false
6060
type: boolean
61-
upload_destination_prefix:
62-
description: "GCS location prefix to where the artifacts should be uploaded"
61+
gcs_upload_uri:
62+
description: "GCS location URI to where the artifacts should be uploaded"
6363
required: false
64-
default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
64+
default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
6565
type: string
6666

6767
jobs:
@@ -95,25 +95,12 @@ jobs:
9595
halt-dispatch-input: ${{ inputs.halt-for-connection }}
9696
- name: Build ${{ inputs.artifact }}
9797
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}"
98-
- name: Set PLATFORM env var for use in upload destination
99-
run: |
100-
os=$(uname -s | awk '{print tolower($0)}')
101-
arch=$(uname -m)
102-
103-
# Adjust name for Windows
104-
if [[ $os =~ "msys_nt" ]]; then
105-
os="windows"
106-
fi
107-
108-
echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV
10998
- name: Upload artifacts to GCS bucket (non-Windows)
11099
if: >-
111100
${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }}
112-
run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/
101+
run: gsutil -m cp -r $(pwd)/dist/*.whl "${{ inputs.gcs_upload_uri }}"/
113102
- name: Upload artifacts to GCS bucket (Windows)
114103
if: >-
115104
${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }}
116105
shell: cmd
117-
run: gsutil -m cp -r dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/
118-
119-
106+
run: gsutil -m cp -r dist/*.whl "${{ inputs.gcs_upload_uri }}"/

0 commit comments

Comments
 (0)