Skip to content

Commit f9ce549

Browse files
committed
Merge branch 'main' of https://github.com/jax-ml/jax
2 parents 6fe9a71 + 1372669 commit f9ce549

File tree

109 files changed

+2618
-1265
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+2618
-1265
lines changed

.bazelrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ build:macos_cache_push --config=macos_cache --remote_upload_local_results=true -
183183
build:ci_linux_x86_64 --config=avx_linux --config=avx_posix
184184
build:ci_linux_x86_64 --config=mkl_open_source_only
185185
build:ci_linux_x86_64 --config=clang --verbose_failures=true
186+
build:ci_linux_x86_64 --color=yes
186187

187188
# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
188189
# toolchain for both CPU and GPU builds.
@@ -203,6 +204,7 @@ build:ci_linux_x86_64_cuda --config=ci_linux_x86_64
203204
# Linux Aarch64 CI configs
204205
build:ci_linux_aarch64_base --config=clang --verbose_failures=true
205206
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
207+
build:ci_linux_aarch64_base --color=yes
206208

207209
build:ci_linux_aarch64 --config=ci_linux_aarch64_base
208210
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
@@ -221,18 +223,21 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm
221223
build:ci_darwin_x86_64 --macos_minimum_os=10.14
222224
build:ci_darwin_x86_64 --config=macos_cache_push
223225
build:ci_darwin_x86_64 --verbose_failures=true
226+
build:ci_darwin_x86_64 --color=yes
224227

225228
# Mac Arm64 CI configs
226229
build:ci_darwin_arm64 --macos_minimum_os=11.0
227230
build:ci_darwin_arm64 --config=macos_cache_push
228231
build:ci_darwin_arm64 --verbose_failures=true
232+
build:ci_darwin_arm64 --color=yes
229233

230234
# Windows x86 CI configs
231235
build:ci_windows_amd64 --config=avx_windows
232236
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
233237
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
234238
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
235239
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
240+
build:ci_windows_amd64 --color=yes
236241

237242
# #############################################################################
238243
# RBE config options below. These inherit the CI configs above and set the

.github/workflows/asan.yaml

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,8 @@ jobs:
2525
run:
2626
shell: bash -l {0}
2727
steps:
28-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
29-
with:
30-
path: jax
31-
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
32-
with:
33-
repository: python/cpython
34-
path: cpython
35-
ref: v3.13.0
28+
# Install git before actions/checkout as otherwise it will download the code with the GitHub
29+
# REST API and therefore any subsequent git commands will fail.
3630
- name: Install clang 18
3731
env:
3832
DEBIAN_FRONTEND: noninteractive
@@ -42,6 +36,14 @@ jobs:
4236
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
4337
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
4438
libffi-dev liblzma-dev
39+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
40+
with:
41+
path: jax
42+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
43+
with:
44+
repository: python/cpython
45+
path: cpython
46+
ref: v3.13.0
4547
- name: Build CPython with ASAN enabled
4648
env:
4749
ASAN_OPTIONS: detect_leaks=0
@@ -65,7 +67,7 @@ jobs:
6567
run: |
6668
source ${GITHUB_WORKSPACE}/venv/bin/activate
6769
cd jax
68-
python build/build.py \
70+
python build/build.py build --wheels=jaxlib --verbose \
6971
--bazel_options=--color=yes \
7072
--bazel_options=--copt=-fsanitize=address \
7173
--clang_path=/usr/bin/clang-18

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
name: CI - Cloud TPU (nightly)
1414
on:
1515
schedule:
16-
- cron: "0 14 * * *" # daily at 7am PST
16+
- cron: "0 */2 * * *" # Run every 2 hours
1717
workflow_dispatch: # allows triggering the workflow run manually
1818
# This should also be set to read-only in the project settings, but it's nice to
1919
# document and enforce the permissions here.
@@ -24,17 +24,20 @@ jobs:
2424
strategy:
2525
fail-fast: false # don't cancel all jobs on failure
2626
matrix:
27-
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
27+
jaxlib-version: ["head", "pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
2828
tpu: [
29-
{type: "v3-8", cores: "4"},
30-
{type: "v4-8", cores: "4"},
31-
{type: "v5e-8", cores: "8"}
29+
# {type: "v3-8", cores: "4"}, # Enable when we have the v3 type available
30+
{type: "v4-8", cores: "4", runner: "linux-x86-ct4p-240-4tpu"},
31+
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
3232
]
33+
python-version: ["3.10"]
3334
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
3435
env:
3536
LIBTPU_OLDEST_VERSION_DATE: 20240722
3637
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
37-
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
38+
PYTHON: python${{ matrix.python-version }}
39+
runs-on: ${{ matrix.tpu.runner }}
40+
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
3841
timeout-minutes: 120
3942
defaults:
4043
run:
@@ -44,54 +47,74 @@ jobs:
4447
# mandates using a specific commit for non-Google actions. We use
4548
# https://github.com/sethvargo/ratchet to pin specific versions.
4649
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
50+
# Checkout XLA at head, if we're building jaxlib at head.
51+
- name: Checkout XLA at head
52+
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
53+
if: ${{ matrix.jaxlib-version == 'head' }}
54+
with:
55+
repository: openxla/xla
56+
path: xla
4757
- name: Install JAX test requirements
4858
run: |
49-
pip install -U -r build/test-requirements.txt
50-
pip install -U -r build/collect-profile-requirements.txt
59+
$PYTHON -m pip install -U -r build/test-requirements.txt
60+
$PYTHON -m pip install -U -r build/collect-profile-requirements.txt
5161
- name: Install JAX
5262
run: |
53-
pip uninstall -y jax jaxlib libtpu
54-
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
55-
pip install .[tpu] \
63+
$PYTHON -m pip uninstall -y jax jaxlib libtpu
64+
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
65+
# Build and install jaxlib at head
66+
$PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67+
--bazel_options="--override_repository=xla=$(pwd)/xla" \
68+
--bazel_options=--color=yes
69+
$PYTHON -m pip install dist/*.whl
70+
71+
# Install "jax" at head
72+
$PYTHON -m pip install -U -e .
73+
74+
# Install libtpu
75+
$PYTHON -m pip install --pre libtpu \
76+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
77+
elif [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
78+
$PYTHON -m pip install .[tpu] \
5679
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5780
5881
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
59-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60-
pip install --pre libtpu \
82+
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
83+
$PYTHON -m pip install --pre libtpu \
6184
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
62-
pip install requests
85+
$PYTHON -m pip install requests
6386
6487
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
65-
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
88+
$PYTHON -m pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6689
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
67-
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
90+
$PYTHON -m pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6891
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
69-
pip install requests
92+
$PYTHON -m pip install requests
7093
else
7194
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7295
exit 1
7396
fi
7497
75-
python3 -c 'import sys; print("python version:", sys.version)'
76-
python3 -c 'import jax; print("jax version:", jax.__version__)'
77-
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
78-
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
79-
python3 -c 'import jax; print("libtpu version:",
98+
$PYTHON -c 'import sys; print("python version:", sys.version)'
99+
$PYTHON -c 'import jax; print("jax version:", jax.__version__)'
100+
$PYTHON -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
101+
strings /usr/local/lib/"$PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
102+
$PYTHON -c 'import jax; print("libtpu version:",
80103
jax.lib.xla_bridge.get_backend().platform_version)'
81104
- name: Run tests
82105
env:
83106
JAX_PLATFORMS: tpu,cpu
84107
PY_COLORS: 1
85108
run: |
86109
# Run single-accelerator tests in parallel
87-
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
110+
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
88111
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
89112
--maxfail=20 -m "not multiaccelerator" tests examples
90113
# Run Pallas printing tests, which need to run with I/O capturing disabled.
91-
TPU_STDERR_LOG_LEVEL=0 python3 -m pytest -s \
114+
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
92115
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
93116
# Run multi-accelerator across all chips
94-
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
117+
$PYTHON -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
95118
- name: Send chat on failure
96119
# Don't notify when testing the workflow from a branch.
97120
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14
31+
ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}

.github/workflows/wheel_win_x64.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
python -m pip install -r build/test-requirements.txt
4141
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
4242
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
43-
python.exe build\build.py `
43+
python.exe build\build.py build --wheels=jaxlib `
4444
--bazel_options=--color=yes `
4545
--bazel_options=--config=win_clang `
4646
--verbose

.github/workflows/windows_ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ jobs:
4949
python -m pip install -r build/test-requirements.txt
5050
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
5151
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
52-
python.exe build\build.py `
52+
python.exe build\build.py build --wheels=jaxlib `
5353
--bazel_options=--color=yes `
54-
--bazel_options=--config=win_clang
54+
--bazel_options=--config=win_clang `
55+
--verbose
5556
5657
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
5758
with:

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1313
## jax 0.4.36
1414

1515
* Breaking Changes
16+
* This release lands "stackless", an internal change to JAX's tracing
17+
machinery. We made trace dispatch purely a function of context rather than a
18+
function of both context and data. This let us delete a lot of machinery for
19+
managing data-dependent tracing: levels, sublevels, `post_process_call`,
20+
`new_base_main`, `custom_bind`, and so on. The change should only affect
21+
users that use JAX internals.
22+
23+
If you do use JAX internals then you may need to
24+
update your code (see
25+
https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
26+
for clues about how to do this). There might also be version skew
27+
issues with JAX libraries that do this. If you find this change breaks your
28+
non-JAX-internals-using code then try the
29+
`config.jax_data_dependent_tracing_fallback` flag as a workaround, and if
30+
you need help updating your code then please file a bug.
1631
* {func}`jax.experimental.jax2tf.convert` with `native_serialization=False`
1732
or with `enable_xla=False` have been deprecated since July 2024, with
1833
JAX version 0.4.31. Now we removed support for these use cases. `jax2tf`
@@ -40,6 +55,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4055
`platforms` instead.
4156
* Hashing of tracers, which has been deprecated since version 0.4.30, now
4257
results in a `TypeError`.
58+
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
59+
replaces previous build.py usage. Run `python build/build.py --help` for
60+
more details. Brief overview of the new subcommand options:
61+
* `build`: Builds JAX wheel packages. For e.g., `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
62+
* `requirements_update`: Updates requirements_lock.txt files.
4363
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
4464
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
4565
on the function inputs.

benchmarks/shape_poly_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import jax
1919
from jax import core
20-
from jax._src.numpy import lax_numpy
2120
from jax import export
2221

2322
jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
7675
while state:
7776
for _ in range(30):
7877
a.scope._clear_caches()
79-
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
78+
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
8079
_ = 0 <= slice_size <= b
8180
_ = start >= 0
8281
_ = start + slice_size <= b

0 commit comments

Comments
 (0)