Skip to content

Commit 28a2652

Browse files
committed
simplify workflow calls
1 parent 6f1995d commit 28a2652

File tree

2 files changed

+65
-41
lines changed

2 files changed

+65
-41
lines changed

.github/workflows/build_artifacts.yml

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,21 @@ on:
1616
- 'no'
1717
workflow_call:
1818
inputs:
19-
build_jax:
20-
description: "Should the jax artifact be built? (1 to enable, 0 to disable)"
19+
wheel_list:
20+
description: "A comma separated list of JAX wheels to build. E.g: jaxlib or jaxlib,jax-cuda-pjrt"
2121
type: string
2222
required: false
23-
default: "0"
24-
build_jaxlib:
25-
description: "Should the jaxlib artifact be built? (1 to enable, 0 to disable)"
23+
default: ""
24+
python_list:
25+
description: "A comma separated list of Python versions to build for. E.g: 3.10 or 3.11,3.12"
2626
type: string
2727
required: false
28-
default: "0"
29-
build_jax_cuda_plugin:
30-
description: "Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)"
28+
default: ""
29+
platform_list:
30+
description: "A comma separated list of platforms to build for. E.g: linux_x86 or linux_x86,linux_arm64,windows_x86"
3131
type: string
3232
required: false
33-
default: "0"
34-
build_jax_cuda_pjrt:
35-
description: "Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)"
36-
type: string
37-
required: false
38-
default: "0"
33+
default: ""
3934
clone_main_xla:
4035
description: "Should latest XLA be used? (1 to enable, 0 to disable)"
4136
type: string
@@ -58,11 +53,13 @@ on:
5853
default: "1"
5954

6055
jobs:
61-
determine_artifact_matrix:
56+
determine_matrix:
6257
runs-on: "linux-x86-n2-16"
6358
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
6459
outputs:
6560
artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }}
61+
python_matrix: ${{ steps.set-matrix.outputs.python_matrix }}
62+
platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }}
6663
defaults:
6764
run:
6865
shell: bash
@@ -74,49 +71,74 @@ jobs:
7471
halt-dispatch-input: ${{ inputs.halt-for-connection }}
7572
- id: set-matrix
7673
run: |
77-
artifacts=()
7874
# Define inputs as bash variables to be able to parse them in
7975
# if conditions
8076
is_workflow_call=${{ inputs.is_workflow_call }}
81-
build_jax=${{ inputs.build_jax }}
82-
build_jaxlib=${{ inputs.build_jaxlib }}
83-
build_jax_cuda_pjrt=${{ inputs.build_jax_cuda_pjrt }}
84-
build_jax_cuda_plugin=${{ inputs.build_jax_cuda_plugin }}
77+
wheel_list=${{ inputs.wheel_list }}
78+
python_list=${{ inputs.python_list }}
79+
platform_list=${{ inputs.platform_list }}
80+
81+
# Initialize the arrays
82+
wheels=()
83+
python_versions=()
84+
platforms=()
8585
86-
# Build every package if not a workflow call
86+
# Build every package for every Python version on every platform if not a workflow call
87+
# Packages that are not supported on a platform won't be built. E.g. CUDA packages won't be
88+
# built for Windows
8789
if [[ ${is_workflow_call:-"0"} == "0" ]]; then
88-
artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'")
90+
wheels=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'")
91+
python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'")
92+
platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-64'")
8993
else
90-
if [[ ${build_jax:-"0"} == "1" ]]; then
91-
artifacts+="'jax',"
92-
fi
94+
# Set the Internal Field Separator to be comma
95+
IFS=,
9396
94-
if [[ ${build_jaxlib:-"0"} == "1" ]]; then
95-
artifacts+="'jaxlib',"
96-
fi
97+
# Wheels
98+
for wheel in $wheel_list; do
99+
wheels+="'$wheel',"
100+
done
97101
98-
if [[ ${build_jax_cuda_pjrt:-"0"} == "1" ]]; then
99-
artifacts+="'jax-cuda-pjrt',"
100-
fi
102+
# Python versions
103+
for python_version in $python_list; do
104+
python_versions+="'$python_version',"
105+
done
101106
102-
if [[ ${build_jax_cuda_plugin:-"0"} == "1" ]]; then
103-
artifacts+="'jax-cuda-plugin'"
104-
fi
107+
# Platforms
108+
for platform in $platform_list; do
109+
if [[ $platform == "linux_x86" ]]; then
110+
platforms+="'linux-x86-n2-16',"
111+
elif [[ $platform == "linux_arm64" ]]; then
112+
platforms+="'linux-arm64-t2a-48',"
113+
elif [[ $platform == "windows_x86" ]]; then
114+
platforms+="'windows-x86-n2-64',"
115+
else
116+
echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86"
117+
exit 1
118+
fi
119+
done
105120
fi
106-
echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT
121+
122+
echo "artifact_matrix=[${wheels[@]}]" >> $GITHUB_OUTPUT
123+
echo "python_matrix=[${python_versions[@]}]" >> $GITHUB_OUTPUT
124+
echo "platform_matrix=[${platforms[@]}]" >> $GITHUB_OUTPUT
125+
126+
echo "Artifacts: $artifact_matrix"
127+
echo "Python versions: $python_matrix"
128+
echo "Platforms: $platform_matrix"
107129
108130
build_artifacts:
109-
needs: determine_artifact_matrix
131+
needs: determine_matrix
110132
continue-on-error: true
111133
defaults:
112134
run:
113135
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
114136
shell: bash
115137
strategy:
116138
matrix:
117-
runner: ["linux-x86-n2-16"] #, "linux-arm64-t2a-48", "windows-x86-n2-64"]
118-
artifact: ${{ fromJSON(needs.determine_artifact_matrix.outputs.artifact_matrix) }}
119-
python: ["3.10", "3.11", "3.12", "3.13"]
139+
runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }}
140+
artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }}
141+
python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }}
120142
# jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
121143
# Python version.
122144
exclude:

.github/workflows/pytest_cpu_reuse.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ jobs:
1919
name: "Build the jaxlib aritfact using latest XLA"
2020
uses: ./.github/workflows/build_artifacts.yml
2121
with:
22-
build_jaxlib: 1
22+
wheel_list: "jaxlib"
23+
python_list: "3.10"
24+
platform_list: "linux_x86,linux_arm64"
2325
clone_main_xla: 1
2426
upload_artifacts: true
2527
upload_destination: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
@@ -34,7 +36,7 @@ jobs:
3436
shell: bash
3537
strategy:
3638
matrix:
37-
runner: ["linux-x86-n2-64"] #, "linux-arm64-t2a-48"]
39+
runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"]
3840
python: ["3.10"]
3941

4042
runs-on: ${{ matrix.runner }}

0 commit comments

Comments
 (0)