Skip to content

Commit b4f7c8e

Browse files
committed
experiment with reusable workflows
1 parent 84f072e commit b4f7c8e

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

.github/workflows/build_artifacts.yml

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,42 @@ on:
1515
- 'yes'
1616
- 'no'
1717
workflow_call:
18+
inputs:
19+
build_jax:
20+
description: "Should the jax artifact be built?"
21+
required: true
22+
default: true
23+
type: boolean
24+
build_jaxlib:
25+
description: "Should the jaxlib artifact be built?"
26+
required: true
27+
default: true
28+
type: boolean
29+
build_jax_cuda_plugin:
30+
description: "Should the jax-cuda-plugin artifact be built?"
31+
required: true
32+
default: true
33+
type: boolean
34+
build_jax_cuda_pjrt:
35+
description: "Should the jax-cuda-pjrt artifact be built?"
36+
required: true
37+
default: true
38+
type: boolean
39+
clone_main_xla:
40+
description: "Should latest XLA be used? (1 to enable, 0 to disable)"
41+
type: string
42+
required: true
43+
default: "0"
44+
upload_artifacts:
45+
description: "Should the artifacts be uploaded to a GCS bucket?"
46+
required: true
47+
default: false
48+
type: boolean
49+
upload_destination:
50+
description: "GCS location to where the artifacts should be uploaded"
51+
required: true
52+
default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
53+
type: string
1854

1955
jobs:
2056
build:
@@ -27,7 +63,7 @@ jobs:
2763
matrix:
2864
runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
2965
artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
30-
python: ["3.10", "3.11", "3.12"]
66+
python: ["3.10"] #, "3.11", "3.12"]
3167
# jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
3268
# Python version.
3369
exclude:
@@ -58,8 +94,8 @@ jobs:
5894
(contains(matrix.runner, 'windows-x86') && null) }}
5995

6096
env:
61-
# Do not run Docker container for Linux runners. Linux runners already run in a Docker container.
62-
JAXCI_RUN_DOCKER_CONTAINER: 0
97+
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
98+
JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
6399

64100
steps:
65101
- uses: actions/checkout@v3
@@ -68,7 +104,19 @@ jobs:
68104
uses: google-ml-infra/actions/ci_connection@main
69105
with:
70106
halt-dispatch-input: ${{ inputs.halt-for-connection }}
71-
- name: Build ${{ matrix.artifact }}
72-
env:
73-
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
74-
run: ./ci/build_artifacts.sh "${{ matrix.artifact }}"
107+
- name: Build jax
108+
if: inputs.build_jax && matrix.artifact == 'jax'
109+
run: ./ci/build_artifacts.sh "jax"
110+
- name: Build jaxlib
111+
if: inputs.build_jaxlib && matrix.artifact == 'jaxlib'
112+
run: ./ci/build_artifacts.sh "jaxlib"
113+
- name: Build jax-cuda-plugin
114+
if: inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin'
115+
run: ./ci/build_artifacts.sh "jax-cuda-plugin"
116+
- name: Build jax-cuda-pjrt
117+
if: inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt'
118+
run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
119+
- name: Upload artifacts to GCS bucket
120+
if: inputs.upload_artifacts
121+
run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"
122+

.github/workflows/pytest_cpu_reuse.yml

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Run Pytest CPU tests (resuable workflow)
22

33
on:
4-
# pull_request:
5-
# branches:
6-
# - main
4+
pull_request:
5+
branches:
6+
- main
77
workflow_dispatch:
88
inputs:
99
halt-for-connection:
@@ -15,28 +15,36 @@ on:
1515
- 'yes'
1616

1717
jobs:
18-
build_jaxlib_artifacts:
18+
build_jaxlib_artifact:
19+
name: "Build the jaxlib aritfact using latest XLA"
1920
uses: ./.github/workflows/build_artifacts.yml
21+
with:
22+
build_jax: false
23+
build_jaxlib: true
24+
build_jax_cuda_plugin: false
25+
build_jax_cuda_pjrt: false
26+
clone_main_xla: 1
27+
upload_artifacts: true
28+
upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
2029

2130
run_pytest:
22-
needs: build_jaxlib_artifacts
31+
name: "Run CPU tests with Pytest"
32+
needs: build_jaxlib_artifact
2333
continue-on-error: true
2434
defaults:
2535
run:
2636
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
2737
shell: bash
2838
strategy:
2939
matrix:
30-
runner: ["windows-x86-n2-64", "linux-x86-n2-64", "linux-arm64-t2a-48"]
40+
runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"]
3141
python: ["3.10"]
3242

3343
runs-on: ${{ matrix.runner }}
3444
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
35-
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
36-
(contains(matrix.runner, 'windows-x86') && null) }}
45+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
3746

3847
env:
39-
JAXCI_CLONE_MAIN_XLA: 1
4048
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
4149

4250
steps:
@@ -46,6 +54,8 @@ jobs:
4654
uses: google-ml-infra/actions/ci_connection@main
4755
with:
4856
halt-dispatch-input: ${{ inputs.halt-for-connection }}
57+
- name: Download the artifacts built in the "build_artifacts" job
58+
run: mkdir -p $(pwd)/dist && ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r gs://general-ml-ci-transient/jax-github-actions/"${{ github.workflow }}"/${{ github.run_number }}/${{ github.run_attempt }} $(pwd)/dist
4959
- name: Install pytest
5060
env:
5161
JAXCI_PYTHON: python${{ matrix.python }}

0 commit comments

Comments
 (0)