@@ -57,24 +57,31 @@ jobs:
5757 runs-on : " linux-x86-n2-16"
5858 container : " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
5959 outputs :
60- build_matrix : ${{ steps.set-matrix.outputs.build_matrix }}
60+ artifact_matrix : ${{ steps.set-matrix.outputs.artifact_matrix }}
6161 steps :
6262 - id : set-matrix
6363 run : |
64- matrix='[]'
65- if [[ ${{ inputs.build_jax }} == "1" ]]; then
66- matrix='["jax"]'
67- if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then
68- matrix='["jax", "jaxlib"]'
64+ artifacts=()
65+ if [[ ${{ github.event }} == "pull_request" ]];
66+ artifacts = ("jax" ", jaxlib" ", jax-cuda-pjrt" ", jax-cuda-plugin")
67+ else
68+ if [[ ${{ inputs.build_jax }} == "1" ]]; then
69+ artifacts+="jax"
70+ fi
71+
72+ if [[ ${{ inputs.build_jaxlib }} == "1" ]]; then
73+ artifacts+=", jaxlib"
74+ fi
75+
6976 if [[ ${{ inputs.build_jax_cuda_pjrt }} == "1" ]]; then
70- matrix='["jax", "jaxlib", "jax-cuda-pjrt"]'
71- if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then
72- matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]'
73- fi
77+ artifacts+=", jax-cuda-pjrt"
78+ fi
79+
80+ if [[ ${{ inputs.build_jax_cuda_plugin }} == "1" ]]; then
81+ artifacts+=", jax-cuda-plugin"
7482 fi
7583 fi
76- fi
77- echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT
84+ echo "artifact_matrix='[${artifacts[@]}]'" >> $GITHUB_OUTPUT
7885
7986 build_artifacts :
8087 needs : determine_matrix
8693 strategy :
8794 matrix :
8895 runner : ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
89- artifact : ${{ fromJSON(needs.determine_matrix.outputs.build_matrix ) }}
96+ artifact : ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix ) }}
9097 python : ["3.10"] # , "3.11", "3.12"]
9198 # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
9299 # Python version.
0 commit comments