|
51 | 51 | required: true |
52 | 52 | default: '${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' |
53 | 53 | type: string |
| 54 | + is_workflow_call: |
| 55 | + description: "Metadata variable to know whether a workflow call was made" |
| 56 | + type: string |
| 57 | + required: true |
| 58 | + default: "1" |
54 | 59 |
|
55 | 60 | jobs: |
56 | 61 | determine_matrix: |
|
70 | 75 | - id: set-matrix |
71 | 76 | run: | |
72 | 77 | artifacts=() |
73 | | - if [[ ${{ github.event_name }} == "pull_request" ]]; then |
| 78 | + # Build every package if not a workflow call |
| 79 | + if [[ ${${{ inputs.is_workflow_call }}:-0} == "0" ]]; then |
74 | 80 | artifacts=("'jax'" ", 'jaxlib'" ", 'jax-cuda-pjrt'" ", 'jax-cuda-plugin'") |
75 | 81 | else |
76 | 82 | if [[ ${${{ inputs.build_jax }}:-0} == "1" ]]; then |
@@ -102,7 +108,7 @@ jobs: |
102 | 108 | matrix: |
103 | 109 | runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"] |
104 | 110 | artifact: ${{ fromJSON(needs.determine_matrix.outputs.artifact_matrix) }} |
105 | | - python: ["3.10"] #, "3.11", "3.12"] |
| 111 | + python: ["3.10", "3.11", "3.12"] |
106 | 112 | # jax-cuda-pjrt and jax are pure Python packages so they do not need to be built for each |
107 | 113 | # Python version. |
108 | 114 | exclude: |
@@ -151,5 +157,5 @@ jobs: |
151 | 157 | - name: Upload artifacts to GCS bucket |
152 | 158 | # Upload if requested and one of the artifacts was built |
153 | 159 | if: inputs.upload_artifacts |
154 | | - run: ~/usr/local/bin/google-cloud-sdk/bin/gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM |
| 160 | + run: gsutil -m cp -r $(pwd)/dist gs://general-ml-ci-transient/jax-github-actions/"${{ inputs.upload_destination }}"/$PLATFORM |
155 | 161 |
|
0 commit comments