Skip to content

Commit 9300174

Browse files
authored
Merge branch 'main' into srnitin/test-scripts-workflows
2 parents b28ff8a + cf81f65 commit 9300174

File tree

208 files changed

+10863
-4364
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

208 files changed

+10863
-4364
lines changed

.bazelrc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
# #############################################################################
22
# All default build options below. These apply to all build commands.
33
# #############################################################################
4+
# TODO: Enable Bzlmod
5+
common --noenable_bzlmod
6+
7+
# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260
8+
common --noincompatible_enable_cc_toolchain_resolution
9+
410
# Make Bazel print out all options from rc files.
511
common --announce_rc
612

@@ -41,6 +47,13 @@ build:linux --copt=-Wno-array-parameter
4147
build:macos --config=posix
4248
build:macos --apple_platform_type=macos
4349

50+
# Bazel 7.0.0 no longer supports dynamic symbol lookup on macOS. To resolve
51+
# undefined symbol errors in macOS arm64 builds, explicitly add the necessary
52+
# linker flags until dependencies are well defined. See
53+
# https://github.com/bazelbuild/bazel/issues/19730.
54+
build:macos --linkopt=-Wl,-undefined,dynamic_lookup
55+
build:macos --host_linkopt=-Wl,-undefined,dynamic_lookup
56+
4457
# Windows has a relatively short command line limit, which JAX has begun to hit.
4558
# See https://docs.bazel.build/versions/main/windows.html
4659
build:windows --features=compiler_param_file
@@ -270,6 +283,12 @@ build:resultstore --bes_instance_name="tensorflow-testing"
270283
build:resultstore --bes_results_url="https://source.cloud.google.com/results/invocations"
271284
build:resultstore --bes_timeout=600s
272285

286+
# Configs for RBE cache. When using resultstore, we need to use these configs
287+
# as well to ensure that the logs that get uploaded to resultstore can be read
288+
# without any errors.
289+
build:rbe_cache --remote_cache=remotebuildexecution.googleapis.com
290+
build:rbe_cache --remote_instance_name=projects/tensorflow-testing/instances/default_instance
291+
273292
build:rbe --config=resultstore
274293
build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
275294
build:rbe --define=EXECUTOR=remote

.bazelversion

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
6.5.0
1+
7.4.1

.github/workflows/asan.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ on:
1717
jobs:
1818
asan:
1919
# Don't execute in fork due to runner type
20-
if: github.repository == 'jax-ml/jax'
20+
if: ${{ github.repository == 'jax-ml/jax' }}
2121
runs-on: linux-x86-n2-64
2222
container:
2323
image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
@@ -62,7 +62,8 @@ jobs:
6262
run: |
6363
source ${GITHUB_WORKSPACE}/venv/bin/activate
6464
cd jax
65-
pip install -r build/test-requirements.txt
65+
pip install uv~=0.5.30
66+
uv pip install -r build/test-requirements.txt
6667
- name: Build and install JAX
6768
env:
6869
ASAN_OPTIONS: detect_leaks=0
@@ -73,8 +74,8 @@ jobs:
7374
--bazel_options=--color=yes \
7475
--bazel_options=--copt=-fsanitize=address \
7576
--clang_path=/usr/bin/clang-18
76-
pip install dist/jaxlib-*.whl
77-
pip install -e .
77+
uv pip install dist/jaxlib-*.whl \
78+
-e .
7879
- name: Run tests
7980
env:
8081
ASAN_OPTIONS: detect_leaks=0
@@ -89,4 +90,4 @@ jobs:
8990
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
9091
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
9192
# The LD_PRELOAD works around https://github.com/google/sanitizers/issues/934#issuecomment-649516500
92-
LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 python -m pytest -n auto --tb=short --maxfail=20 tests
93+
LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 python -m pytest -n 32 --tb=short --maxfail=20 tests

.github/workflows/bazel_cuda_non_rbe.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
# Explicitly set the shell to bash
4848
shell: bash
4949
runs-on: ${{ inputs.runner }}
50-
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cudnn9.1-cuda12.3:720686788"
50+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest"
5151

5252
env:
5353
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}

.github/workflows/build_artifacts.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ on:
1717
options:
1818
- "linux-x86-n2-16"
1919
- "linux-arm64-c4a-64"
20-
- "windows-x86-n2-16"
20+
- "windows-x86-n2-64"
2121
artifact:
2222
description: "Which JAX artifact to build?"
2323
type: choice
@@ -125,19 +125,19 @@ jobs:
125125
with:
126126
repository: openxla/xla
127127
path: jax/xla
128-
- name: Enable RBE if building on Linux x86 or Windows x86
129-
if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
128+
- name: Enable RBE if building on Linux x86
129+
if: contains(inputs.runner, 'linux-x86')
130130
run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV
131-
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64
132-
if: contains(inputs.runner, 'linux-arm64')
131+
- name: Enable Bazel remote cache (with writes enabled) if building on Linux Aarch64 or Windows x86
132+
if: contains(inputs.runner, 'linux-arm64') || contains(inputs.runner, 'windows-x86')
133133
run: echo "JAXCI_WRITE_TO_BAZEL_REMOTE_CACHE=1" >> $GITHUB_ENV
134134
# Halt for testing
135135
- name: Wait For Connection
136136
uses: google-ml-infra/actions/ci_connection@main
137137
with:
138138
halt-dispatch-input: ${{ inputs.halt-for-connection }}
139139
- name: Build ${{ inputs.artifact }}
140-
timeout-minutes: 30
140+
timeout-minutes: 60
141141
run: ./ci/build_artifacts.sh "${{ inputs.artifact }}"
142142
- name: Upload artifacts to a GCS bucket (non-Windows runs)
143143
if: >-

.github/workflows/ci-build.yaml

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
python-version: ${{ matrix.python-version }}
7676
- name: Install dependencies
7777
run: |
78-
pip install uv
78+
pip install uv~=0.5.30
7979
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
8080
8181
- name: Run tests
@@ -88,7 +88,7 @@ jobs:
8888
JAX_SKIP_SLOW_TESTS: true
8989
PY_COLORS: 1
9090
run: |
91-
pip install -e .
91+
uv pip install --system -e .
9292
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
9393
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
9494
echo "JAX_ENABLE_CUSTOM_PRNG=$JAX_ENABLE_CUSTOM_PRNG"
@@ -113,7 +113,7 @@ jobs:
113113
python-version: ${{ matrix.python-version }}
114114
- name: Install dependencies
115115
run: |
116-
pip install uv
116+
pip install uv~=0.5.30
117117
uv pip install --system -r docs/requirements.txt
118118
- name: Test documentation
119119
env:
@@ -147,7 +147,7 @@ jobs:
147147
python-version: ${{ matrix.python-version }}
148148
- name: Install dependencies
149149
run: |
150-
pip install uv
150+
pip install uv~=0.5.30
151151
uv pip install --system -r docs/requirements.txt
152152
- name: Render documentation
153153
run: |
@@ -173,8 +173,9 @@ jobs:
173173
python-version: ${{ matrix.python-version }}
174174
- name: Install dependencies
175175
run: |
176-
pip install uv
177-
uv pip install --system .[minimum-jaxlib] tensorflow -r build/test-requirements.txt
176+
pip install uv~=0.5.30
177+
uv pip install --system .[minimum-jaxlib] -r build/test-requirements.txt
178+
uv pip install --system --pre tensorflow==2.19.0rc0
178179
179180
- name: Run tests
180181
env:
@@ -184,7 +185,7 @@ jobs:
184185
JAX_SKIP_SLOW_TESTS: true
185186
PY_COLORS: 1
186187
run: |
187-
pip install -e .
188+
uv pip install --system -e .
188189
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
189190
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
190191
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
@@ -205,7 +206,7 @@ jobs:
205206
python-version: 3.12
206207
- name: Install JAX
207208
run: |
208-
pip install uv
209+
pip install uv~=0.5.30
209210
uv pip install --system .[cuda12]
210211
- name: Build and install example project
211212
run: uv pip install --system ./examples/ffi[test]

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,42 +59,38 @@ jobs:
5959
git config --global --add safe.directory "$GITHUB_WORKSPACE"
6060
- name: Install JAX test requirements
6161
run: |
62-
$PYTHON -m pip install -U -r build/test-requirements.txt
63-
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
62+
$PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
6463
- name: Install JAX
6564
run: |
66-
$PYTHON -m pip uninstall -y jax jaxlib libtpu
65+
$PYTHON -m uv pip uninstall jax jaxlib libtpu
6766
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
6867
# Build and install jaxlib at head
6968
$PYTHON build/build.py build --wheels=jaxlib \
7069
--bazel_options=--config=rbe_linux_x86_64 \
7170
--local_xla_path="$(pwd)/xla" \
7271
--verbose
7372
74-
$PYTHON -m pip install dist/*.whl
75-
76-
# Install "jax" at head
77-
$PYTHON -m pip install -U -e .
78-
79-
# Install libtpu
80-
$PYTHON -m pip install --pre libtpu \
81-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
73+
# Install jaxlib, "jax" at head, and libtpu
74+
$PYTHON -m uv pip install dist/*.whl \
75+
-U -e . \
76+
--pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8277
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
83-
$PYTHON -m pip install .[tpu] \
78+
$PYTHON -m uv pip install .[tpu] \
8479
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
8580
8681
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
87-
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
88-
$PYTHON -m pip install --pre libtpu \
89-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
90-
$PYTHON -m pip install requests
82+
$PYTHON -m uv pip install \
83+
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
84+
libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
85+
requests
9186
9287
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
93-
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
9488
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
95-
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
96-
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
97-
$PYTHON -m pip install requests
89+
$PYTHON -m uv pip install \
90+
--pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
91+
libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
92+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
93+
requests
9894
else
9995
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
10096
exit 1

.github/workflows/cloud-tpu-ci-presubmit.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ jobs:
7474
git config --global --add safe.directory "$GITHUB_WORKSPACE"
7575
- name: Install JAX test requirements
7676
run: |
77-
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
78-
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
77+
$JAXCI_PYTHON -m uv pip install -U -r build/test-requirements.txt -r build/collect-profile-requirements.txt
7978
- name: Build jaxlib at head with latest XLA
8079
run: |
8180
# Build and install jaxlib at head
@@ -86,7 +85,7 @@ jobs:
8685
--verbose
8786
8887
# Install libtpu
89-
$JAXCI_PYTHON -m pip install --pre libtpu \
88+
$JAXCI_PYTHON -m uv pip install --pre libtpu \
9089
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
9190
# Halt for testing
9291
- name: Wait For Connection

.github/workflows/jax-array-api.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ jobs:
3737
python-version: ${{ matrix.python-version }}
3838
- name: Install dependencies
3939
run: |
40-
pip install uv
41-
uv pip install --system .[ci]
42-
uv pip install --system pytest-xdist -r array-api-tests/requirements.txt
40+
pip install uv~=0.5.30
41+
uv pip install --system .[ci] pytest-xdist -r array-api-tests/requirements.txt
4342
- name: Run the test suite
4443
env:
4544
ARRAY_API_TESTS_MODULE: jax.numpy

.github/workflows/metal_plugin_ci.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,14 @@ jobs:
3535
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
3636
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
3737
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
38-
pip install -U pip numpy wheel
39-
pip install absl-py pytest
38+
pip install uv~=0.5.30
39+
uv pip install -U pip numpy wheel absl-py pytest
4040
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
41-
pip install --pre jaxlib \
41+
uv pip install --pre jaxlib \
4242
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
4343
fi;
4444
cd jax
45-
pip install .
46-
pip install jax-metal
45+
uv pip install . jax-metal
4746
- name: Run test
4847
run: |
4948
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate

0 commit comments

Comments
 (0)