# Test CI scripts and workflows #238
# NOTE: GitHub's viewer warned that this file may contain hidden or bidirectional
# Unicode text that could be interpreted or compiled differently than it appears.
# Review it in an editor that reveals hidden Unicode characters (see GitHub's
# guidance on bidirectional Unicode characters).
name: Build JAX Artifacts

# `on` is a YAML 1.1 boolean to generic parsers; GitHub's loader reads it as a key.
on:  # yamllint disable-line rule:truthy
  pull_request:
    branches:
      - main
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'
  workflow_call:
    # Every input below carries a default, so none of them is marked required.
    # GitHub rejects a call that omits a required input, which would make these
    # defaults dead and force every caller to pass every value explicitly.
    inputs:
      build_jax:
        description: "Should the jax artifact be built? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "1"
      build_jaxlib:
        description: "Should the jaxlib artifact be built? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "1"
      build_jax_cuda_plugin:
        description: "Should the jax-cuda-plugin artifact be built? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "1"
      build_jax_cuda_pjrt:
        description: "Should the jax-cuda-pjrt artifact be built? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "1"
      clone_main_xla:
        description: "Should latest XLA be used? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "0"
      upload_artifacts:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        required: false
        default: false
        type: boolean
      upload_destination:
        description: "GCS location to where the artifacts should be uploaded"
        required: false
        default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
      is_workflow_call:
        # Callers should leave this at its default. It is set only on
        # workflow_call runs; on pull_request / workflow_dispatch runs the
        # input is absent and expands to an empty string, which the jobs
        # treat as "not a workflow call".
        description: "Metadata variable to know whether a workflow call was made"
        type: string
        required: false
        default: "1"
jobs:
  # Decides which artifacts to build and publishes them as a JSON array.
  determine_matrix:
    runs-on: "linux-x86-n2-16"
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    outputs:
      artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }}
    defaults:
      run:
        shell: bash
    steps:
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Writes `artifact_matrix`, a JSON array such as ["jax","jaxlib"], to the
      # step outputs for the build_artifacts matrix.
      - id: set-matrix
        # Inputs reach the script through the environment instead of being
        # spliced into its text. An unset input then expands to "" so the
        # `:-0` fallbacks apply. (Splicing produced `${:-0}`, a bash "bad
        # substitution", or `${1:-0}`, an unset positional parameter.) Input
        # values also cannot inject shell this way.
        env:
          IS_WORKFLOW_CALL: ${{ inputs.is_workflow_call }}
          BUILD_JAX: ${{ inputs.build_jax }}
          BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
          BUILD_JAX_CUDA_PJRT: ${{ inputs.build_jax_cuda_pjrt }}
          BUILD_JAX_CUDA_PLUGIN: ${{ inputs.build_jax_cuda_plugin }}
        run: |
          artifacts=()
          # Build every package if not a workflow call.
          if [[ "${IS_WORKFLOW_CALL:-0}" == "0" ]]; then
            artifacts=(jax jaxlib jax-cuda-pjrt jax-cuda-plugin)
          else
            if [[ "${BUILD_JAX:-0}" == "1" ]]; then
              artifacts+=(jax)
            fi
            if [[ "${BUILD_JAXLIB:-0}" == "1" ]]; then
              artifacts+=(jaxlib)
            fi
            if [[ "${BUILD_JAX_CUDA_PJRT:-0}" == "1" ]]; then
              artifacts+=(jax-cuda-pjrt)
            fi
            if [[ "${BUILD_JAX_CUDA_PLUGIN:-0}" == "1" ]]; then
              artifacts+=(jax-cuda-plugin)
            fi
          fi
          # Serialize as JSON with double-quoted strings so fromJSON() accepts
          # it for any subset of artifacts; `${json:+,}` adds a separator
          # only after the first element.
          json=""
          for artifact in "${artifacts[@]}"; do
            json+="${json:+,}\"${artifact}\""
          done
          echo "artifact_matrix=[${json}]" >> "$GITHUB_OUTPUT"

  build_artifacts:
    needs: determine_matrix
    # An empty matrix vector is an error, so skip the job when a caller
    # disabled every artifact.
    if: needs.determine_matrix.outputs.artifact_matrix != '[]'
    # NOTE(review): job-level continue-on-error keeps build failures from
    # failing the workflow run; confirm this is intended beyond testing.
    continue-on-error: true
    defaults:
      run:
        # Explicitly set the shell to bash to override the default Windows environment, i.e., cmd.
        shell: bash
    strategy:
      matrix:
        runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
        artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }}
        python: ["3.10", "3.11", "3.12"]
        exclude:
          # jax-cuda-pjrt and jax do not need to be built for each Python
          # version; only the 3.12 entry is kept.
          - artifact: "jax-cuda-pjrt"
            python: "3.10"
          - artifact: "jax-cuda-pjrt"
            python: "3.11"
          - artifact: "jax"
            python: "3.10"
          - artifact: "jax"
            python: "3.11"
          # jax is a pure Python package so it does not need to be built on multiple platforms.
          - artifact: "jax"
            runner: "windows-x86-n2-64"
          - artifact: "jax"
            runner: "linux-arm64-t2a-16"
          # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows.
          - artifact: "jax-cuda-plugin"
            runner: "windows-x86-n2-64"
          - artifact: "jax-cuda-pjrt"
            runner: "windows-x86-n2-64"
    runs-on: ${{ matrix.runner }}
    # Picks the build image per platform. For Windows the expression
    # evaluates to null, presumably so the job runs directly on the runner
    # without a container.
    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
      (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
      (contains(matrix.runner, 'windows-x86') && null) }}
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
      JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
    steps:
      # v4 replaces v3, which runs on the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Build ${{ matrix.artifact }}
        run: ./ci/build_artifacts.sh "${{ matrix.artifact }}"
      # Exposes e.g. "Linux_x86_64" as $PLATFORM to later steps.
      - name: Set Platform
        run: |
          echo "PLATFORM=$(uname)_$(uname -m)" >> "$GITHUB_ENV"
      - name: Upload artifacts to GCS bucket
        # Upload only when requested. The implicit success() check also skips
        # this step when an earlier step (e.g. the build) failed.
        if: inputs.upload_artifacts
        # Passed via env rather than spliced into the script text so the
        # value cannot inject shell.
        env:
          UPLOAD_DESTINATION: ${{ inputs.upload_destination }}
        run: |
          gsutil -m cp -r "$(pwd)/dist" \
            "gs://general-ml-ci-transient/jax-github-actions/${UPLOAD_DESTINATION}/${PLATFORM}"