Test CI scripts and workflows #267
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "Build JAX Artifacts"

on:
  # Pull requests against main build the full default matrix (every wheel,
  # every Python version, every platform).
  pull_request:
    branches: [main]

  # Manual runs also build the full default matrix and can optionally pause
  # so an engineer can attach to the runner for debugging.
  workflow_dispatch:
    inputs:
      halt-for-connection:
        type: choice
        description: 'Should this workflow run wait for a remote connection?'
        required: true
        default: 'no'
        options: ['yes', 'no']

  # Reusable-workflow entry point: the caller chooses exactly which wheels,
  # Python versions and platforms to build, and whether to upload the results.
  workflow_call:
    inputs:
      wheel_list:
        type: string
        description: "A comma separated list of JAX wheels to build. E.g: jaxlib or jaxlib,jax-cuda-pjrt"
        required: false
        default: ""
      python_list:
        type: string
        description: "A comma separated list of Python versions to build for. E.g: 3.10 or 3.11,3.12"
        required: false
        default: ""
      platform_list:
        type: string
        description: "A comma separated list of platforms to build for. E.g: linux_x86 or linux_x86,linux_arm64,windows_x86"
        required: false
        default: ""
      # A "1"/"0" string (not a boolean) that is forwarded to the build
      # as JAXCI_CLONE_MAIN_XLA.
      clone_main_xla:
        type: string
        description: "Should latest XLA be used? (1 to enable, 0 to disable)"
        required: false
        default: "0"
      upload_artifacts:
        type: boolean
        description: "Should the artifacts be uploaded to a GCS bucket?"
        required: false
        default: false
      upload_destination_prefix:
        type: string
        description: "GCS location prefix to where the artifacts should be uploaded"
        required: false
        default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      # Callers are not expected to set this. Its default of "1" is what the
      # determine_matrix job checks to tell a workflow call apart from
      # pull_request / workflow_dispatch runs, where it is not defined.
      is_workflow_call:
        type: string
        description: "Metadata variable to know whether a workflow call was made"
        required: false
        default: "1"
jobs:
  # Computes the build matrix (wheels x Python versions x runners) as JSON
  # arrays that build_artifacts consumes via fromJSON.
  determine_matrix:
    runs-on: "linux-x86-n2-16"
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    outputs:
      artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }}
      python_matrix: ${{ steps.set-matrix.outputs.python_matrix }}
      platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }}
    defaults:
      run:
        shell: bash
    steps:
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: "Determine the matrix"
        id: set-matrix
        # Inputs are passed through the environment instead of being
        # interpolated into the script text, so values containing spaces or
        # shell metacharacters cannot break (or inject into) the script.
        env:
          IS_WORKFLOW_CALL: ${{ inputs.is_workflow_call }}
          WHEEL_LIST: ${{ inputs.wheel_list }}
          PYTHON_LIST: ${{ inputs.python_list }}
          PLATFORM_LIST: ${{ inputs.platform_list }}
        run: |
          # Print the argument with leading and trailing whitespace removed.
          trim() {
            local s="$1"
            s="${s#"${s%%[![:space:]]*}"}"
            s="${s%"${s##*[![:space:]]}"}"
            printf '%s' "$s"
          }

          # Convert a comma-separated list into a JSON array of double-quoted
          # strings, trimming each entry and skipping empty ones.
          # E.g. "3.11, 3.12" -> ["3.11","3.12"]. The output must be valid JSON
          # because build_artifacts parses it with fromJSON.
          csv_to_json_array() {
            local json="" item
            local -a items=()
            IFS=',' read -ra items <<< "$1"
            for item in "${items[@]}"; do
              item="$(trim "$item")"
              if [[ -n "$item" ]]; then
                json+="${json:+,}\"${item}\""
              fi
            done
            printf '[%s]' "$json"
          }

          # Build every package for every Python version on every platform if not a workflow call
          # (inputs.is_workflow_call only has its "1" default under workflow_call).
          # Packages that are not supported on a platform won't be built (see the exclude list in
          # build_artifacts). E.g. CUDA packages won't be built for Windows.
          if [[ "${IS_WORKFLOW_CALL:-0}" == "0" ]]; then
            wheel_csv="jaxlib,jax-cuda-pjrt,jax-cuda-plugin"
            python_csv="3.10,3.11,3.12,3.13"
            platform_csv="linux-x86-n2-16,linux-arm64-t2a-48,windows-x86-n2-64"
          else
            wheel_csv="$WHEEL_LIST"
            python_csv="$PYTHON_LIST"
            # Translate the user-facing platform names into runner labels.
            platform_csv=""
            requested_platforms=()
            IFS=',' read -ra requested_platforms <<< "$PLATFORM_LIST"
            for platform in "${requested_platforms[@]}"; do
              platform="$(trim "$platform")"
              case "$platform" in
                linux_x86) runner="linux-x86-n2-16" ;;
                linux_arm64) runner="linux-arm64-t2a-48" ;;
                windows_x86) runner="windows-x86-n2-64" ;;
                "") continue ;;
                *)
                  echo "Incorrect platform provided. Valid options are: linux_x86, linux_arm64, windows_x86"
                  exit 1
                  ;;
              esac
              platform_csv+="${platform_csv:+,}${runner}"
            done
          fi

          artifact_matrix="$(csv_to_json_array "$wheel_csv")"
          python_matrix="$(csv_to_json_array "$python_csv")"
          platform_matrix="$(csv_to_json_array "$platform_csv")"

          # An empty list would hand GitHub an empty matrix vector, which fails
          # build_artifacts with an opaque error; fail here with a clear one.
          for matrix in "$artifact_matrix" "$python_matrix" "$platform_matrix"; do
            if [[ "$matrix" == "[]" ]]; then
              echo "wheel_list, python_list and platform_list must each name at least one entry"
              exit 1
            fi
          done

          {
            echo "artifact_matrix=${artifact_matrix}"
            echo "python_matrix=${python_matrix}"
            echo "platform_matrix=${platform_matrix}"
          } >> "$GITHUB_OUTPUT"
          echo "Artifacts: ${artifact_matrix}"
          echo "Python versions: ${python_matrix}"
          echo "Platforms: ${platform_matrix}"
| build_artifacts: | |
| needs: determine_matrix | |
| defaults: | |
| run: | |
| # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. | |
| shell: bash | |
| strategy: | |
| fail-fast: false # don't cancel all jobs on failure | |
| matrix: | |
| runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} | |
| artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} | |
| python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} | |
| exclude: | |
| # jax-cuda-pjrt does not need to be built for every Python but excluding it here for | |
| # every but one Python version causes issues when a workflow call is made to this file | |
| # requesting a build for an exlcuded Python version (see pytest_gpu.yaml) | |
| # | |
| # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. | |
| - artifact: "jax-cuda-plugin" | |
| runner: "windows-x86-n2-64" | |
| - artifact: "jax-cuda-pjrt" | |
| runner: "windows-x86-n2-64" | |
| runs-on: ${{ matrix.runner }} | |
| container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | |
| (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | |
| (contains(matrix.runner, 'windows-x86') && null) }} | |
| env: | |
| JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" | |
| JAXCI_BUILD_ARTIFACT_WITH_RBE: 1 | |
| JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" | |
| steps: | |
| - uses: actions/checkout@v3 | |
| - name: Enable RBE on platforms where its supported | |
| run: | | |
| os=$(uname -s | awk '{print tolower($0)}') | |
| arch=$(uname -m) | |
| if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then | |
| echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV | |
| fi | |
| # Halt for testing | |
| - name: Wait For Connection | |
| uses: google-ml-infra/actions/ci_connection@main | |
| with: | |
| halt-dispatch-input: ${{ inputs.halt-for-connection }} | |
| - name: Build ${{ matrix.artifact }} | |
| run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" | |
| - name: Set PLATFORM env var for use in upload destination | |
| run: | | |
| os=$(uname -s | awk '{print tolower($0)}') | |
| arch=$(uname -m) | |
| # Adjust name for Windows | |
| if [[ $os =~ "msys_nt" ]]; then | |
| os="windows" | |
| fi | |
| echo "PLATFORM=${os}_${arch}" >> $GITHUB_ENV | |
| - name: Upload artifacts to GCS bucket | |
| if: inputs.upload_artifacts | |
| run: gsutil -m cp -r $(pwd)/dist/*.whl gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination_prefix }}"/$PLATFORM/python${JAXCI_HERMETIC_PYTHON_VERSION}/ | |