# Test CI scripts and workflows (#213)
# NOTE: GitHub's file viewer warned that this file may contain hidden or
# bidirectional Unicode text that may be interpreted or compiled differently
# than it appears. Review it in an editor that reveals hidden Unicode characters.
name: Build JAX Artifacts

on:
  pull_request:
    branches:
      - main
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        # Quoted so YAML 1.1 parsers do not read these as booleans.
        options:
          - 'yes'
          - 'no'
  # Reusable-workflow entry point. The inputs are optional so callers only pass
  # the values they want to override. An input marked `required: true` must
  # always be supplied by the caller, which would make its `default` unreachable.
  workflow_call:
    inputs:
      build_jax:
        description: "Should the jax artifact be built?"
        required: false
        default: true
        type: boolean
      build_jaxlib:
        description: "Should the jaxlib artifact be built?"
        required: false
        default: true
        type: boolean
      build_jax_cuda_plugin:
        description: "Should the jax-cuda-plugin artifact be built?"
        required: false
        default: true
        type: boolean
      build_jax_cuda_pjrt:
        description: "Should the jax-cuda-pjrt artifact be built?"
        required: false
        default: true
        type: boolean
      # Exported to the build job as JAXCI_CLONE_MAIN_XLA.
      clone_main_xla:
        description: "Should latest XLA be used? (1 to enable, 0 to disable)"
        type: string
        required: false
        default: "0"
      upload_artifacts:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        required: false
        default: false
        type: boolean
      # Appended to gs://general-ml-ci-transient/jax-github-actions/ by the
      # upload step.
      upload_destination:
        description: "GCS location to where the artifacts should be uploaded"
        required: false
        default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
jobs:
  build:
    # A failing matrix entry does not fail the overall workflow run.
    continue-on-error: true
    defaults:
      run:
        # Explicitly use bash everywhere; on Windows runners the default shell
        # would otherwise be PowerShell.
        shell: bash
    strategy:
      matrix:
        runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
        artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]
        python: ["3.10"]  # "3.11", "3.12" currently disabled
        exclude:
          # jax and jax-cuda-pjrt do not need to be built for each Python
          # version, so build them only with 3.10. Exclude the *other* versions
          # rather than 3.10 itself. Excluding 3.10 would remove these
          # artifacts from the matrix entirely while 3.10 is the only enabled
          # version. Entries for versions absent from the matrix match nothing.
          - artifact: "jax-cuda-pjrt"
            python: "3.11"
          - artifact: "jax-cuda-pjrt"
            python: "3.12"
          - artifact: "jax"
            python: "3.11"
          - artifact: "jax"
            python: "3.12"
          # jax is a pure Python package, so it does not need to be built on
          # multiple platforms.
          - artifact: "jax"
            runner: "windows-x86-n2-64"
          - artifact: "jax"
            runner: "linux-arm64-t2a-16"
          # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows.
          - artifact: "jax-cuda-plugin"
            runner: "windows-x86-n2-64"
          - artifact: "jax-cuda-pjrt"
            runner: "windows-x86-n2-64"
    runs-on: ${{ matrix.runner }}
    # Linux builds run inside the ml-build image matching the runner's
    # architecture. For Windows the expression evaluates to null, so the job
    # runs directly on the host without a container.
    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                   (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
                   (contains(matrix.runner, 'windows-x86') && null) }}
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}"
      JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
    steps:
      # v4 runs on Node 20; v3 targets the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      # Optionally pause so a developer can connect to the runner; controlled
      # by the halt-for-connection dispatch input.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      # Each step builds its artifact only in the matching matrix entry.
      # NOTE(review): the build_* and upload_artifacts inputs are declared only
      # under workflow_call. On pull_request / workflow_dispatch runs they are
      # unset, so these conditions are false and the builds are skipped.
      # Confirm that is intended.
      - name: Build jax
        if: inputs.build_jax && matrix.artifact == 'jax'
        run: ./ci/build_artifacts.sh "jax"
      - name: Build jaxlib
        if: inputs.build_jaxlib && matrix.artifact == 'jaxlib'
        run: ./ci/build_artifacts.sh "jaxlib"
      - name: Build jax-cuda-plugin
        if: inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin'
        run: ./ci/build_artifacts.sh "jax-cuda-plugin"
      - name: Build jax-cuda-pjrt
        if: inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt'
        run: ./ci/build_artifacts.sh "jax-cuda-pjrt"
      # Copies ./dist to the transient CI bucket under upload_destination.
      # NOTE(review): `~/usr/local/bin/...` resolves under $HOME. Confirm the
      # gsutil location; an absolute /usr/local/bin/... path may have been
      # intended.
      - name: Upload artifacts to GCS bucket
        if: inputs.upload_artifacts
        run: |
          ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r "$(pwd)/dist" \
            gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"