@@ -59,42 +59,38 @@ jobs:
5959 git config --global --add safe.directory "$GITHUB_WORKSPACE"
6060 - name : Install JAX test requirements
6161 run : |
62- $PYTHON -m pip install -U -r build/test-requirements.txt
63- $PYTHON -m pip install -U -r build/collect-profile-requirements.txt
62+ $PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
6463 - name : Install JAX
6564 run : |
66- $PYTHON -m pip uninstall -y jax jaxlib libtpu
65+ $PYTHON -m uv pip uninstall jax jaxlib libtpu
6766 if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
6867 # Build and install jaxlib at head
6968 $PYTHON build/build.py build --wheels=jaxlib \
7069 --bazel_options=--config=rbe_linux_x86_64 \
7170 --local_xla_path="$(pwd)/xla" \
7271 --verbose
7372
74- $PYTHON -m pip install dist/*.whl
75-
76- # Install "jax" at head
77- $PYTHON -m pip install -U -e .
78-
79- # Install libtpu
80- $PYTHON -m pip install --pre libtpu \
81- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
73+ # Install jaxlib, "jax" at head, and libtpu
74+ $PYTHON -m uv pip install dist/*.whl \
75+ -U -e . \
76+ --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8277 elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
83- $PYTHON -m pip install .[tpu] \
78+ $PYTHON -m uv pip install .[tpu] \
8479 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8580
8681 elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
87- $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
88- $PYTHON -m pip install --pre libtpu \
89- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
90- $PYTHON -m pip install requests
82+ $PYTHON -m uv pip install \
83+ --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
84+ libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
85+ requests
9186
9287 elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
93- $PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9488 # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
95- $PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
96- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
97- $PYTHON -m pip install requests
89+ $PYTHON -m uv pip install \
90+ --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
91+ libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
92+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
93+ requests
9894 else
9995 echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
10096 exit 1
0 commit comments