|
| 1 | +name: Build JAX Artifacts |
| 2 | + |
| 3 | +on: |
| 4 | + # pull_request: |
| 5 | + # branches: |
| 6 | + # - main |
| 7 | + workflow_dispatch: |
| 8 | + inputs: |
| 9 | + halt-for-connection: |
| 10 | + description: 'Should this workflow run wait for a remote connection?' |
| 11 | + type: choice |
| 12 | + required: true |
| 13 | + default: 'no' |
| 14 | + options: |
| 15 | + - 'yes' |
| 16 | + - 'no' |
| 17 | + workflow_call: |
| 18 | + |
| 19 | +jobs: |
| 20 | + build: |
| 21 | + continue-on-error: true |
| 22 | + defaults: |
| 23 | + run: |
| 24 | + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. |
| 25 | + shell: bash |
| 26 | + strategy: |
| 27 | + matrix: |
| 28 | + runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] |
| 29 | + artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] |
| 30 | + python: ["3.10", "3.11", "3.12"] |
| 31 | + # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each |
| 32 | + # Python version. |
| 33 | + exclude: |
| 34 | + # Pure Python packages do not need to be built for each Python version. |
| 35 | + - artifact: "jax-cuda-pjrt" |
| 36 | + python: "3.10" |
| 37 | + - artifact: "jax-cuda-pjrt" |
| 38 | + python: "3.11" |
| 39 | + - artifact: "jax" |
| 40 | + python: "3.10" |
| 41 | + - artifact: "jax" |
| 42 | + python: "3.11" |
| 43 | + # jax is a pure Python package so it does not need to be built on multiple platforms. |
| 44 | + - artifact: "jax" |
| 45 | + runner: "windows-x86-n2-64" |
| 46 | + - artifact: "jax" |
| 47 | + runner: "linux-arm64-t2a-16" |
| 48 | + # jax-cuda-plugin and jax-cuda-pjrt are not supported on Windows. |
| 49 | + - artifact: "jax-cuda-plugin" |
| 50 | + runner: "windows-x86-n2-64" |
| 51 | + - artifact: "jax-cuda-pjrt" |
| 52 | + runner: "windows-x86-n2-64" |
| 53 | + |
| 54 | + runs-on: ${{ matrix.runner }} |
| 55 | + |
| 56 | + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
| 57 | + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') || |
| 58 | + (contains(matrix.runner, 'windows-x86') && null) }} |
| 59 | + |
| 60 | + env: |
| 61 | + # Do not run Docker container for Linux runners. Linux runners already run in a Docker container. |
| 62 | + JAXCI_RUN_DOCKER_CONTAINER: 0 |
| 63 | + |
| 64 | + steps: |
| 65 | + - uses: actions/checkout@v3 |
| 66 | + # Halt for testing |
| 67 | + - name: Wait For Connection |
| 68 | + uses: google-ml-infra/actions/ci_connection@main |
| 69 | + with: |
| 70 | + halt-dispatch-input: ${{ inputs.halt-for-connection }} |
| 71 | + - name: Build ${{ matrix.artifact }} |
| 72 | + env: |
| 73 | + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" |
| 74 | + run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" |
0 commit comments