|
16 | 16 | - 'no' |
17 | 17 | workflow_call: |
18 | 18 | inputs: |
| 19 | + runner: |
| 20 | + description: "Which runner should the workflow run on?" |
| 21 | + type: string |
| 22 | + required: true |
| 23 | + default: "linux-x86-n2-16" |
| 24 | + artifact: |
| 25 | + description: "Which JAX artifact to build?" |
| 26 | + type: string |
| 27 | + required: true |
| 28 | + default: "jaxlib" |
| 29 | + python-version: |
| 30 | + description: "Which python version should the artifact be built for?" |
| 31 | + type: string |
| 32 | + required: true |
| 33 | + default: "3.12" |
19 | 34 | clone_main_xla: |
20 | 35 | description: "Should latest XLA be used? (1 to enable, 0 to disable)" |
21 | 36 | type: string |
|
33 | 48 | type: string |
34 | 49 |
|
35 | 50 | jobs: |
36 | | - determine_matrix: |
37 | | - runs-on: "linux-x86-n2-16" |
38 | | - container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest" |
39 | | - outputs: |
40 | | - artifact_matrix: ${{ steps.set-matrix.outputs.artifact_matrix }} |
41 | | - python_matrix: ${{ steps.set-matrix.outputs.python_matrix }} |
42 | | - platform_matrix: ${{ steps.set-matrix.outputs.platform_matrix }} |
43 | | - defaults: |
44 | | - run: |
45 | | - shell: bash |
46 | | - steps: |
47 | | - # Halt for testing |
48 | | - - name: Wait For Connection |
49 | | - uses: google-ml-infra/actions/ci_connection@main |
50 | | - with: |
51 | | - halt-dispatch-input: ${{ inputs.halt-for-connection }} |
52 | | - - name: "Determine the matrix" |
53 | | - id: set-matrix |
54 | | - run: | |
55 | | - echo ${{ matrix.workflow_call_runner }} |
56 | | -
|
57 | 51 | build_artifacts: |
58 | | - needs: determine_matrix |
59 | 52 | defaults: |
60 | 53 | run: |
61 | 54 | # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. |
62 | 55 | shell: bash |
63 | 56 | strategy: |
64 | 57 | fail-fast: false # don't cancel all jobs on failure |
65 | 58 | matrix: |
66 | | - runner: ${{ fromJSON(needs.determine_matrix.outputs.platform_matrix) }} |
| 59 | + runner: ${{ }} |
67 | 60 | artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} |
68 | 61 | python: ${{ fromJSON(needs.determine_matrix.outputs.python_matrix) }} |
69 | 62 | exclude: |
|
73 | 66 | - runner: "windows-x86-n2-16" |
74 | 67 | artifact: "jax-cuda-plugin" |
75 | 68 |
|
76 | | - runs-on: ${{ matrix.runner }} |
| 69 | + runs-on: ${{ inputs.runner }} |
77 | 70 |
|
78 | 71 | container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
79 | 72 | (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || |
|
0 commit comments