|
39 | 39 | gcs_download_uri: |
40 | 40 | description: "GCS location URI from where the artifacts should be downloaded" |
41 | 41 | required: false |
42 | | - default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}' |
| 42 | + default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' |
43 | 43 | type: string |
44 | 44 |
|
45 | 45 | jobs: |
|
50 | 50 | shell: bash |
51 | 51 |
|
52 | 52 | runs-on: ${{ inputs.runner }} |
53 | | - container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
54 | | - (contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || |
| 53 | + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || |
| 54 | + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || |
55 | 55 | (contains(inputs.runner, 'windows-x86') && null) }} |
56 | 56 |
|
57 | 57 | name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" |
@@ -88,26 +88,14 @@ jobs: |
88 | 88 | echo "OS=${os}" >> $GITHUB_ENV |
89 | 89 | echo "ARCH=${arch}" >> $GITHUB_ENV |
90 | 90 | echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV |
91 | | - - name: Download wheel artifacts from GCS (non-Windows runs) |
92 | | - if: ${{ !contains(inputs.runner, 'windows-x86') }} |
| 91 | + - name: Download wheel artifacts from GCS |
93 | 92 | run: | |
94 | | - mkdir -p $(pwd)/dist && |
95 | | - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ |
| 93 | + mkdir -p dist && |
| 94 | + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/ |
96 | 95 | |
97 | 96 | # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 |
98 | 97 | if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then |
99 | | - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/ |
100 | | - fi |
101 | | - - name: Download the jaxlib wheel from GCS (Windows runs) |
102 | | - if: ${{ contains(inputs.runner, 'windows-x86') }} |
103 | | - shell: cmd |
104 | | - run: >- |
105 | | - mkdir dist && |
106 | | - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/ |
107 | | - |
108 | | - # Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1 |
109 | | - if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then |
110 | | - gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/ |
| 98 | + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" dist/ |
111 | 99 | fi |
112 | 100 | - name: Install dependencies |
113 | 101 | run: $JAXCI_PYTHON -m pip install -r build/requirements.in |
|
0 commit comments