|
62 | 62 | strategy: |
63 | 63 | matrix: |
64 | 64 | runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] |
65 | | - artifact: ["jax", "jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] |
| 65 | + artifact: >- |
| 66 | + ${{ fromJSON((inputs.build_jax && '["jax"]') || '[]') }} + |
| 67 | + ${{ fromJSON((inputs.build_jaxlib && '["jaxlib"]') || '[]') }} + |
| 68 | + ${{ fromJSON((inputs.build_jax_cuda_pjrt && '["jax-cuda-pjrt"]') || '[]') }} + |
| 69 | + ${{ fromJSON((inputs.build_jax_cuda_plugin && '["jax-cuda-plugin"]') || '[]') }} |
66 | 70 | python: ["3.10"] #, "3.11", "3.12"] |
67 | 71 | # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each |
68 | 72 | # Python version. |
@@ -97,46 +101,20 @@ jobs: |
97 | 101 | JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" |
98 | 102 | JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" |
99 | 103 |
|
100 | | - if: inputs.build_jax || inputs.build_jaxlib || inputs.build_jax_cuda_plugin || inputs.build_jax_cuda_pjrt |
101 | 104 | steps: |
102 | 105 | - uses: actions/checkout@v3 |
103 | | - if: >- |
104 | | - (inputs.build_jax && matrix.artifact == 'jax') || |
105 | | - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || |
106 | | - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || |
107 | | - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') |
108 | 106 | # Halt for testing |
109 | 107 | - name: Wait For Connection |
110 | 108 | uses: google-ml-infra/actions/ci_connection@main |
111 | 109 | with: |
112 | 110 | halt-dispatch-input: ${{ inputs.halt-for-connection }} |
113 | | - if: >- |
114 | | - (inputs.build_jax && matrix.artifact == 'jax') || |
115 | | - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || |
116 | | - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || |
117 | | - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') |
118 | 111 | - name: Build ${{ matrix.artifact }} |
119 | | - if: >- |
120 | | - (inputs.build_jax && matrix.artifact == 'jax') || |
121 | | - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || |
122 | | - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || |
123 | | - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') |
124 | 112 | run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" |
125 | 113 | - name: Set Platform |
126 | | - if: >- |
127 | | - (inputs.build_jax && matrix.artifact == 'jax') || |
128 | | - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || |
129 | | - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || |
130 | | - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') |
131 | 114 | run: | |
132 | 115 | echo "PLATFORM=$(uname)_$(uname -m)" >> $GITHUB_ENV |
133 | 116 | - name: Upload artifacts to GCS bucket |
134 | 117 | # Upload if requested and one of the artifacts was built |
135 | | - if: >- |
136 | | - inputs.upload_artifacts && |
137 | | - (inputs.build_jax && matrix.artifact == 'jax') || |
138 | | - (inputs.build_jaxlib && matrix.artifact == 'jaxlib') || |
139 | | - (inputs.build_jax_cuda_plugin && matrix.artifact == 'jax-cuda-plugin') || |
140 | | - (inputs.build_jax_cuda_pjrt && matrix.artifact == 'jax-cuda-pjrt') |
| 118 | + if: inputs.upload_artifacts |
141 | 119 | run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM |
142 | 120 |
|
0 commit comments