1616 - ' no'
1717 workflow_call :
1818 inputs :
19- build_jax :
20- description : " Should the jax artifact be built? (1 to enable, 0 to disable) "
19+ wheel_list :
20+ description : " A comma separated list of JAX wheels to build. E.g: jaxlib or jaxlib,jax-cuda-pjrt "
2121 type : string
2222 required : false
23- default : " 0 "
24- build_jaxlib :
25- description : " Should the jaxlib artifact be built? (1 to enable, 0 to disable) "
23+ default : " "
24+ python_list :
25+ description : " A comma separated list of Python versions to build for. E.g: 3.10 or 3.11,3.12 "
2626 type : string
2727 required : false
28- default : " 0 "
29- build_jax_cuda_plugin :
30- description : " Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable) "
28+ default : " "
29+ platform_list :
30+ description : " A comma separated list of platforms to build for. E.g: linux_x86 or linux_x86,linux_arm64,windows_x86 "
3131 type : string
3232 required : false
33- default : " 0"
34- build_jax_cuda_pjrt :
35- description : " Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)"
36- type : string
37- required : false
38- default : " 0"
33+ default : " "
3934 clone_main_xla :
4035 description : " Should latest XLA be used? (1 to enable, 0 to disable)"
4136 type : string
5853 default : " 1"
5954
6055jobs :
61- determine_artifact_matrix :
56+ determine_matrix :
6257 runs-on : " linux-x86-n2-16"
6358 container : " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
6459 outputs :
6560 artifact_matrix : ${{ steps.set-matrix.outputs.artifact_matrix }}
61+ python_matrix : ${{ steps.set-matrix.outputs.python_matrix }}
62+ platform_matrix : ${{ steps.set-matrix.outputs.platform_matrix }}
6663 defaults :
6764 run :
6865 shell : bash
@@ -74,49 +71,74 @@ jobs:
7471 halt-dispatch-input : ${{ inputs.halt-for-connection }}
7572 - id : set-matrix
7673 run : |
77- artifacts=()
7874 # Define inputs as bash variables to be able to parse them in
7975 # if conditions
8076 is_workflow_call=${{ inputs.is_workflow_call }}
81- build_jax=${{ inputs.build_jax }}
82- build_jaxlib=${{ inputs.build_jaxlib }}
83- build_jax_cuda_pjrt=${{ inputs.build_jax_cuda_pjrt }}
84- build_jax_cuda_plugin=${{ inputs.build_jax_cuda_plugin }}
77+ wheel_list=${{ inputs.wheel_list }}
78+ python_list=${{ inputs.python_list }}
79+ platform_list=${{ inputs.platform_list }}
80+
81+ # Initialize the arrays
82+ wheels=()
83+ python_versions=()
84+ platforms=()
8585
86- # Build every package if not a workflow call
86+ # Build every package for every Python version on every platform if not a workflow call
87+ # Packages that are not supported on a platform won't be built. E.g. CUDA packages won't be
88+ # built for Windows
8789 if [[ ${is_workflow_call:-"0"} == "0" ]]; then
88- artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'")
90+ wheels=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'")
91+ python_versions=("'3.10'" ", '3.11'" ", '3.12'", ", '3.13'")
92+ platforms=("'linux-x86-n2-16'" ", 'linux-arm64-t2a-48'" ", 'windows-x86-n2-64'")
8993 else
90- if [[ ${build_jax:-"0"} == "1" ]]; then
91- artifacts+="'jax',"
92- fi
94+ # Set the Internal Field Separator to be comma
95+ IFS=,
9396
94- if [[ ${build_jaxlib:-"0"} == "1" ]]; then
95- artifacts+="'jaxlib',"
96- fi
97+ # Wheels
98+ for wheel in $wheel_list; do
99+ wheels+="'$wheel',"
100+ done
97101
98- if [[ ${build_jax_cuda_pjrt:-"0"} == "1" ]]; then
99- artifacts+="'jax-cuda-pjrt',"
100- fi
102+ # Python versions
103+ for python_version in $python_list; do
104+ python_versions+="'$python_version',"
105+ done
101106
102- if [[ ${build_jax_cuda_plugin:-"0"} == "1" ]]; then
103- artifacts+="'jax-cuda-plugin'"
104- fi
107+ # Platforms
108+ for platform in $platform_list; do
109+ if [[ $platform == "linux_x86" ]]; then
110+ platforms+="'linux-x86-n2-16',"
111+ elif [[ $platform == "linux_arm64" ]]; then
112+ platforms+="'linux-arm64-t2a-48',"
113+ elif [[ $platform == "windows_x86" ]]; then
114+ platforms+="'windows-x86-n2-64',"
115+ else
116+ echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86"
117+ exit 1
118+ fi
119+ done
105120 fi
106- echo "artifact_matrix=[${artifacts[@]}]" >> $GITHUB_OUTPUT
121+
122+ echo "artifact_matrix=[${wheels[@]}]" >> $GITHUB_OUTPUT
123+ echo "python_matrix=[${python_versions[@]}]" >> $GITHUB_OUTPUT
124+ echo "platform_matrix=[${platforms[@]}]" >> $GITHUB_OUTPUT
125+
126+ echo "Artifacts: $artifact_matrix"
127+ echo "Python versions: $python_matrix"
128+ echo "Platforms: $platform_matrix"
107129
108130 build_artifacts :
109- needs : determine_artifact_matrix
131+ needs : determine_matrix
110132 continue-on-error : true
111133 defaults :
112134 run :
113135 # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
114136 shell : bash
115137 strategy :
116138 matrix :
117- runner : ["linux-x86-n2-16"] # , "linux-arm64-t2a-48", "windows-x86-n2-64"]
118- artifact : ${{ fromJSON(needs.determine_artifact_matrix .outputs.artifact_matrix) }}
119- python : ["3.10", "3.11", "3.12", "3.13"]
139+ runner : ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }}
140+ artifact : ${{ fromJSON(needs.determine_matrix .outputs.artifact_matrix) }}
141+ python : ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }}
120142 # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each
121143 # Python version.
122144 exclude :
0 commit comments