Skip to content

Commit 68ea9eb

Browse files
committed
Use the correct Python binary to install deps
1 parent bd34770 commit 68ea9eb

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

.github/workflows/pytest_tpu.yml

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,40 +37,39 @@ jobs:
3737
LIBTPU_OLDEST_VERSION_DATE: 20240722
3838
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
3939
JAXCI_HERMETIC_PYTHON_VERSION: ${{ matrix.python }}
40+
JAXCI_PYTHON: python${{ matrix.python }}
4041

4142
steps:
4243
- uses: actions/checkout@v3
43-
# Halt for testing
44-
- name: Wait For Connection
45-
uses: google-ml-infra/actions/ci_connection@main
46-
with:
47-
halt-dispatch-input: ${{ inputs.halt-for-connection }}
4844
- name: Install JAX test requirements
49-
env:
50-
JAXCI_PYTHON: python${{ matrix.python }}
5145
run: |
52-
pip install -U -r build/test-requirements.txt
53-
pip install -U -r build/collect-profile-requirements.txt
46+
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
47+
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5448
- name: Install JAX
5549
run: |
5650
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
57-
pip install .[tpu] \
51+
$JAXCI_PYTHON -m pip install .[tpu] \
5852
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5953
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
60-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
61-
pip install --pre libtpu \
54+
$JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
55+
$JAXCI_PYTHON -m pip install --pre libtpu \
6256
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
63-
pip install requests
57+
$JAXCI_PYTHON -m pip install requests
6458
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
59+
$JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6660
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67-
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
61+
$JAXCI_PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6862
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69-
pip install requests
63+
$JAXCI_PYTHON -m pip install requests
7064
else
7165
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7266
exit 1
7367
fi
68+
# Halt for testing
69+
- name: Wait For Connection
70+
uses: google-ml-infra/actions/ci_connection@main
71+
with:
72+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
7473
- name: Run Pytest TPU tests
7574
env:
7675
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}

0 commit comments

Comments
 (0)