Skip to content

Commit 50ef46e

Browse files
committed
Remove windows specific download steps
1 parent 0b92e10 commit 50ef46e

File tree

5 files changed

+33
-45
lines changed

5 files changed

+33
-45
lines changed

.github/workflows/bazel_cuda_non_rbe.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ on:
3131
gcs_download_uri:
3232
description: "GCS location URI from where the artifacts should be downloaded"
3333
required: true
34-
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
34+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
3535
type: string
3636

3737

@@ -69,8 +69,8 @@ jobs:
6969
- name: Download the wheel artifacts from GCS
7070
run: >-
7171
mkdir -p $(pwd)/dist &&
72-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
73-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
74-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/
72+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
73+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.wh"l" $(pwd)/dist/ &&
74+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
7575
- name: Run Bazel tests
7676
run: ./ci/run_bazel_test_gpu_non_rbe.sh

.github/workflows/build_artifacts.yml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,20 @@ on:
5353
type: string
5454
required: false
5555
default: "0"
56-
upload_artifacts:
56+
upload_artifacts_to_gcs:
5757
description: "Should the artifacts be uploaded to a GCS bucket?"
5858
required: false
5959
default: false
6060
type: boolean
6161
gcs_upload_uri:
6262
description: "GCS location URI to where the artifacts should be uploaded"
6363
required: false
64-
default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
64+
default: ' gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
6565
type: string
66+
outputs:
67+
gcs_upload_uri:
68+
description: "GCS location prefix to where the artifacts were uploaded"
69+
value: ${{ inputs.gcs_upload_uri }}
6670

6771
jobs:
6872
build_artifacts:
@@ -95,12 +99,7 @@ jobs:
9599
halt-dispatch-input: ${{ inputs.halt-for-connection }}
96100
- name: Build ${{ inputs.artifact }}
97101
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}"
98-
- name: Upload artifacts to GCS bucket (non-Windows)
99-
if: >-
100-
${{ inputs.upload_artifacts && !contains(inputs.runner, 'windows-x86') }}
101-
run: gsutil -m cp -r $(pwd)/dist/*.whl "${{ inputs.gcs_upload_uri }}"/
102-
- name: Upload artifacts to GCS bucket (Windows)
102+
- name: Upload artifacts to GCS bucket
103103
if: >-
104-
${{ inputs.upload_artifacts && contains(inputs.runner, 'windows-x86') }}
105-
shell: cmd
106-
run: gsutil -m cp -r dist/*.whl "${{ inputs.gcs_upload_uri }}"/
104+
${{ inputs.upload_artifacts_to_gcs }}
105+
run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/

.github/workflows/pytest_cpu.yml

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ on:
3939
gcs_download_uri:
4040
description: "GCS location URI from where the artifacts should be downloaded"
4141
required: false
42-
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
42+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
4343
type: string
4444

4545
jobs:
@@ -50,8 +50,8 @@ jobs:
5050
shell: bash
5151

5252
runs-on: ${{ inputs.runner }}
53-
container: ${{ (contains(inputs.runner, 'linux-x86-n2') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
54-
(contains(inputs.runner, 'linux-x86-t2a') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
53+
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
54+
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
5555
(contains(inputs.runner, 'windows-x86') && null) }}
5656

5757
name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
@@ -88,26 +88,14 @@ jobs:
8888
echo "OS=${os}" >> $GITHUB_ENV
8989
echo "ARCH=${arch}" >> $GITHUB_ENV
9090
echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV
91-
- name: Download wheel artifacts from GCS (non-Windows runs)
92-
if: ${{ !contains(inputs.runner, 'windows-x86') }}
91+
- name: Download wheel artifacts from GCS
9392
run: |
94-
mkdir -p $(pwd)/dist &&
95-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/
93+
mkdir -p dist &&
94+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/
9695
9796
# Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1
9897
if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then
99-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/
100-
fi
101-
- name: Download the jaxlib wheel from GCS (Windows runs)
102-
if: ${{ contains(inputs.runner, 'windows-x86') }}
103-
shell: cmd
104-
run: >-
105-
mkdir dist &&
106-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl dist/
107-
108-
# Download the "jax" wheel from GCS if inputs.install_latest_jax is not set to 1
109-
if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then
110-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*.whl $(pwd)/dist/
98+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" dist/
11199
fi
112100
- name: Install dependencies
113101
run: $JAXCI_PYTHON -m pip install -r build/requirements.in

.github/workflows/pytest_cuda.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ on:
4444
gcs_download_uri:
4545
description: "GCS location URI from where the artifacts should be downloaded"
4646
required: true
47-
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
47+
default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
4848
type: string
4949

5050
jobs:
@@ -84,13 +84,13 @@ jobs:
8484
- name: Download the wheel artifacts from GCS
8585
run: |
8686
mkdir -p $(pwd)/dist &&
87-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
88-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl $(pwd)/dist/ &&
89-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*cuda*pjrt*${OS}*${ARCH}*.whl $(pwd)/dist/
87+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
88+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ &&
89+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/
9090
9191
# Download the "jax" wheel from GCS if
9292
if [[ ${{ inputs.install_jax_current_commit }} != 1 ]]; then
93-
gsutil -m cp -r "${{ inputs.gcs_download_uri }}"/jax*py3*none*any*.whl $(pwd)/dist/
93+
gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*py3*none*any*.whl" $(pwd)/dist/
9494
fi
9595
- name: Install dependencies
9696
run: $JAXCI_PYTHON -m pip install -r build/requirements.in

.github/workflows/wheel_tests_continuous.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ jobs:
3535
artifact: ${{ matrix.artifact }}
3636
python: ${{ matrix.python }}
3737
clone_main_xla: 1
38-
upload_artifacts: true
39-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
38+
upload_artifacts_to_gcs: true
39+
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
4040

4141
build_cuda_artifacts:
4242
uses: ./.github/workflows/build_artifacts.yml
@@ -52,8 +52,8 @@ jobs:
5252
artifact: ${{ matrix.artifact }}
5353
python: ${{ matrix.python }}
5454
clone_main_xla: 1
55-
upload_artifacts: true
56-
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
55+
upload_artifacts_to_gcs: true
56+
gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
5757

5858
run_pytest_cpu:
5959
needs: build_jaxlib_artifact
@@ -69,7 +69,7 @@ jobs:
6969
runner: ${{ matrix.runner }}
7070
python: ${{ matrix.python }}
7171
enable-x64: ${{ matrix.enable-x64 }}
72-
gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
72+
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
7373
install_jax_current_commit: "1"
7474

7575

@@ -89,7 +89,8 @@ jobs:
8989
python: ${{ matrix.python }}
9090
cuda: ${{ matrix.cuda }}
9191
enable-x64: ${{ matrix.enable-x64 }}
92-
gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
92+
# GCS upload URI is the same for both artifact build jobs
93+
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}
9394
install_jax_current_commit: "1"
9495

9596
run_bazel_test_gpu:
@@ -106,4 +107,4 @@ jobs:
106107
runner: ${{ matrix.runner }}
107108
python: ${{ matrix.python }}
108109
enable-x64: ${{ matrix.enable-x64 }}
109-
gcs_download_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}'
110+
gcs_download_uri: ${{ needs.build_jaxlib_artifact.outputs.gcs_upload_uri }}

0 commit comments

Comments
 (0)