Skip to content

Commit 6eca13c

Browse files
committed
change how artifact matrix is constructed
1 parent e4b45d1 commit 6eca13c

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

.github/workflows/build_artifacts.yml

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,31 @@ jobs:
5757
runs-on: "linux-x86-n2-16"
5858
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
5959
outputs:
60-
build_matrix: ${{ steps.set-matrix.outputs.build_matrix }}
60+
artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }}
6161
steps:
6262
- id: set-matrix
6363
run: |
64-
matrix='[]'
65-
if [[ ${{ inputs.build_jax }} == "1" ]]; then
66-
matrix='["jax"]'
67-
if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then
68-
matrix='["jax", "jaxlib"]'
64+
artifacts=()
65+
if [[ ${{ github.event }} == "pull_request" ]];
66+
artifacts = ("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin")
67+
else
68+
if [[ ${{ inputs.build_jax }} == "1" ]]; then
69+
artifacts+="jax"
70+
fi
71+
72+
if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then
73+
artifacts+=", jaxlib"
74+
fi
75+
6976
if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]]; then
70-
matrix='["jax", "jaxlib", "jax-cuda-pjrt"]'
71-
if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then
72-
matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]'
73-
fi
77+
artifacts+=", jax-cuda-pjrt"
78+
fi
79+
80+
if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then
81+
artifacts+=", jax-cuda-plugin"
7482
fi
7583
fi
76-
fi
77-
echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT
84+
echo "artifact_matrix='[${artifacts[@]}]'" >> $GITHUB_OUTPUT
7885
7986
build_artifacts:
8087
needs: determine_matrix
@@ -86,7 +93,7 @@ jobs:
8693
strategy:
8794
matrix:
8895
runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
89-
artifact: ${{ fromJSON(needs.determine_matrix.outputs.build_matrix) }}
96+
artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }}
9097
python: ["3.10"] #, "3.11", "3.12"]
9198
# jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
9299
# Python version.

0 commit comments

Comments
 (0)