|
53 | 53 | type: string |
54 | 54 |
|
55 | 55 | jobs: |
| 56 | + determine_matrix: |
| 57 | + runs-on: linux-x86-n2-16 |
| 58 | + container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') |
| 59 | + outputs: |
| 60 | + build_matrix: ${{ steps.set-matrix.outputs.build_matrix }} |
| 61 | + steps: |
| 62 | + - id: set-matrix |
| 63 | + run: | |
| 64 | + matrix='[]' |
| 65 | + if ${{ inputs.build_jax }}; then |
| 66 | + matrix='["jax"]' |
| 67 | + if ${{ inputs.build_jaxlib }}; then |
| 68 | + matrix='["jax", "jaxlib"]' |
| 69 | + if ${{ inputs.build_jax_cuda_pjrt }}; then |
| 70 | + matrix='["jax", "jaxlib", "jax-cuda-pjrt"]' |
| 71 | + if ${{ inputs.build_jax_cuda_plugin }}; then |
| 72 | + matrix='["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]' |
| 73 | + fi |
| 74 | + fi |
| 75 | + fi |
| 76 | + fi |
| 77 | + echo "build_matrix=${matrix}" >> $GITHUB_OUTPUT |
| 78 | +
|
56 | 79 | build_artifacts: |
57 | 80 | continue-on-error: true |
58 | 81 | defaults: |
|
62 | 85 | strategy: |
63 | 86 | matrix: |
64 | 87 | runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] |
65 | | - artifact: >- |
66 | | - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && inputs.build_jax_cuda_plugin && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"]') }} || |
67 | | - ${{ inputs.build_jax && inputs.build_jaxlib && inputs.build_jax_cuda_pjrt && fromJSON('["jax", "jaxlib", "jax-cuda-pjrt"]') }} || |
68 | | - ${{ inputs.build_jax && inputs.build_jaxlib && fromJSON('["jax", "jaxlib"]') }} || |
69 | | - ${{ inputs.build_jax && fromJSON('["jax"]') }} || |
70 | | - ${{ fromJSON('[]') }} |
| 88 | + artifact: ${{ fromJSON(needs.determine_matrix.outputs.build_matrix) }} |
71 | 89 | python: ["3.10"] #, "3.11", "3.12"] |
72 | 90 | # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each |
73 | 91 | # Python version. |
|
0 commit comments