@@ -37,40 +37,39 @@ jobs:
3737 LIBTPU_OLDEST_VERSION_DATE : 20240722
3838 ENABLE_PJRT_COMPATIBILITY : ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
3939 JAXCI_HERMETIC_PYTHON_VERSION : ${{ matrix.python }}
40+ JAXCI_PYTHON : python${{ matrix.python }}
4041
4142 steps :
4243 - uses : actions/checkout@v3
43- # Halt for testing
44- - name : Wait For Connection
45- uses : google-ml-infra/actions/ci_connection@main
46- with :
47- halt-dispatch-input : ${{ inputs.halt-for-connection }}
4844 - name : Install JAX test requirements
49- env :
50- JAXCI_PYTHON : python${{ matrix.python }}
5145 run : |
52- pip install -U -r build/test-requirements.txt
53- pip install -U -r build/collect-profile-requirements.txt
46+ $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
47+ $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5448 - name : Install JAX
5549 run : |
5650 if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
57- pip install .[tpu] \
51+ $JAXCI_PYTHON -m pip install .[tpu] \
5852 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5953 elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
60- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
61- pip install --pre libtpu \
54+ $JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
55+ $JAXCI_PYTHON -m pip install --pre libtpu \
6256 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
63- pip install requests
57+ $JAXCI_PYTHON -m pip install requests
6458 elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65- pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
59+ $JAXCI_PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6660 # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67- pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
61+ $JAXCI_PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6862 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69- pip install requests
63+ $JAXCI_PYTHON -m pip install requests
7064 else
7165 echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7266 exit 1
7367 fi
68+ # Halt for testing
69+ - name : Wait For Connection
70+ uses : google-ml-infra/actions/ci_connection@main
71+ with :
72+ halt-dispatch-input : ${{ inputs.halt-for-connection }}
7473 - name : Run Pytest TPU tests
7574 env :
7675 JAXCI_TPU_CORES : ${{ matrix.tpu.cores }}
0 commit comments