Commit 402814b

Merge branch 'main' of https://github.com/jax-ml/jax

2 parents 5a5e219 + 302d803 commit 402814b

132 files changed: +3079 -1497 lines

.bazelrc

Lines changed: 5 additions & 0 deletions

@@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX
 
 build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
 
+# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL).
+build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
+build:mkl_aarch64_threadpool --@compute_library//:openmp=false
+build:mkl_aarch64_threadpool -c opt
+
 # Disable clang extension that rejects type definitions within offsetof.
 # This was added in clang-16 by https://reviews.llvm.org/D133574.
 # Can be removed once upb is updated, since a type definition is used within

.github/workflows/asan.yaml

Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,8 @@ on:
 
 jobs:
   asan:
+    # Don't execute in fork due to runner type
+    if: github.repository == 'jax-ml/jax'
     runs-on: linux-x86-n2-64
     container:
       image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04

.github/workflows/ci-build.yaml

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ jobs:
     - run: pre-commit run --show-diff-on-failure --color=always --all-files
 
   build:
+    # Don't execute in fork due to runner type
+    if: github.repository == 'jax-ml/jax'
     name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
     runs-on: linux-x86-n2-32
     container:

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 name: CI - Cloud TPU (nightly)
 on:
   schedule:
-    - cron: "0 */2 * * *" # Run every 2 hours
+    - cron: "0 2,14 * * *" # Run at 7am and 7pm PST
   workflow_dispatch: # allows triggering the workflow run manually
 # This should also be set to read-only in the project settings, but it's nice to
 # document and enforce the permissions here.
@@ -33,12 +33,12 @@ jobs:
         python-version: ["3.10"]
     name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
     env:
-      LIBTPU_OLDEST_VERSION_DATE: 20240722
+      LIBTPU_OLDEST_VERSION_DATE: 20240922
       ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
       PYTHON: python${{ matrix.python-version }}
     runs-on: ${{ matrix.tpu.runner }}
     container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
-    timeout-minutes: 120
+    timeout-minutes: 180
     defaults:
       run:
         shell: bash -ex {0}
Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+# Cloud TPU CI (presubmit)
+#
+# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
+# tested to get to a stable state before we enable it as a blocking presubmit.
+name: CI - Cloud TPU (presubmit)
+on:
+  workflow_dispatch:
+    inputs:
+      halt-for-connection:
+        description: 'Should this workflow run wait for a remote connection?'
+        type: choice
+        required: true
+        default: 'no'
+        options:
+        - 'yes'
+        - 'no'
+  pull_request:
+    branches:
+      - main
+
+# This should also be set to read-only in the project settings, but it's nice to
+# document and enforce the permissions here.
+permissions:
+  contents: read
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cloud-tpu-test:
+    if: github.event.repository.fork == false
+    strategy:
+      fail-fast: false # don't cancel all jobs on failure
+      matrix:
+        tpu: [
+          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
+        ]
+        python-version: ["3.10"]
+
+    name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
+
+    env:
+      JAXCI_PYTHON: python${{ matrix.python-version }}
+      JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
+
+    runs-on: ${{ matrix.tpu.runner }}
+    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
+
+    timeout-minutes: 60
+
+    defaults:
+      run:
+        shell: bash -ex {0}
+
+    steps:
+      # https://opensource.google/documentation/reference/github/services#actions
+      # mandates using a specific commit for non-Google actions. We use
+      # https://github.com/sethvargo/ratchet to pin specific versions.
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      # Checkout XLA at head, if we're building jaxlib at head.
+      - name: Checkout XLA at head
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          repository: openxla/xla
+          path: xla
+      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
+      - name: Mark GitHub workspace as safe
+        run: |
+          git config --global --add safe.directory "$GITHUB_WORKSPACE"
+      - name: Install JAX test requirements
+        run: |
+          $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
+          $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
+      - name: Build jaxlib at head with latest XLA
+        run: |
+          # Build and install jaxlib at head
+          $JAXCI_PYTHON build/build.py build --wheels=jaxlib \
+            --python_version=${{ matrix.python-version }} \
+            --bazel_options=--config=rbe_linux_x86_64 \
+            --local_xla_path="$(pwd)/xla" \
+            --verbose
+
+          # Install libtpu
+          $JAXCI_PYTHON -m pip install --pre libtpu \
+            -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+      # Halt for testing
+      - name: Wait For Connection
+        uses: google-ml-infra/actions/ci_connection@main
+        with:
+          halt-dispatch-input: ${{ inputs.halt-for-connection }}
+      - name: Install jaxlib wheel and run tests
+        run: ./ci/run_pytest_tpu.sh

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ repos:
   - id: mypy
     files: (jax/|tests/typing_test\.py)
    exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
-    additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0]
+    additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
     args: [--config=pyproject.toml]
 
 - repo: https://github.com/mwouts/jupytext

CHANGELOG.md

Lines changed: 19 additions & 6 deletions

@@ -12,13 +12,26 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
 
 ## jax 0.4.38
 
+* Changes:
+  * `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
+    as shortcuts of the corresponding `tree_util` functions.
+
 * Deprecations
-  * a number of APIs in the internal `jax.core` namespace have been deprecated, including
-    `ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
-    `Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
-    APIs of the same name in {mod}`jax.extend.core`; see the documentation for
-    {mod}`jax.extend` for information on the compatibility guarantees of these
-    semi-public extensions.
+  * a number of APIs in the internal `jax.core` namespace have been deprecated.
+    Most were no-ops, were little-used, or can be replaced by APIs of the same
+    name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
+    for information on the compatibility guarantees of these semi-public extensions.
+  * Several previously-deprecated APIs have been removed, including:
+    * from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
+      `non_negative_dim`.
+    * from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
+    * from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
+    * from {mod}`jax.numpy`: `round_`.
+
+* New Features
+  * {func}`jax.export.export` can be used for device-polymorphic export with
+    shardings constructed with {func}`jax.sharding.AbstractMesh`.
+    See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
 
 ## jax 0.4.37 (Dec 9, 2024)
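The new `jax.tree` aliases behave like their `jax.tree_util` counterparts. A minimal usage sketch (assuming jax 0.4.38; the `params` dict and the doubling rule are illustrative):

    import jax

    params = {"b": 2.0, "w": 3.0}

    # flatten_with_path returns (path, leaf) pairs plus a treedef,
    # like jax.tree_util.tree_flatten_with_path.
    path_leaf_pairs, treedef = jax.tree.flatten_with_path(params)
    for path, leaf in path_leaf_pairs:
        print(jax.tree_util.keystr(path), "->", leaf)  # ['b'] -> 2.0, ['w'] -> 3.0

    # map_with_path passes each leaf's path to the mapped function,
    # like jax.tree_util.tree_map_with_path.
    doubled_w = jax.tree.map_with_path(
        lambda path, x: 2 * x if jax.tree_util.keystr(path) == "['w']" else x,
        params)
    assert doubled_w == {"b": 2.0, "w": 6.0}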

build/build.py

Lines changed: 4 additions & 1 deletion

@@ -485,7 +485,10 @@ async def main():
 
   if not args.disable_mkl_dnn:
     logging.debug("Enabling MKL DNN")
-    wheel_build_command.append("--config=mkl_open_source_only")
+    if target_cpu == "aarch64":
+      wheel_build_command.append("--config=mkl_aarch64_threadpool")
+    else:
+      wheel_build_command.append("--config=mkl_open_source_only")
 
   if args.target_cpu_features == "release":
     if arch in ["x86_64", "AMD64"]:
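The effect of this branch, isolated as a standalone sketch (the helper name is hypothetical; the config names come from the `.bazelrc` change in this commit):

    # Hypothetical helper isolating the selection logic added above.
    def mkl_dnn_config(target_cpu: str) -> str:
        # aarch64 wheels build oneDNN with ACL on a threadpool (OpenMP disabled
        # via the new mkl_aarch64_threadpool config); other targets keep the
        # open-source oneDNN contraction kernels.
        if target_cpu == "aarch64":
            return "--config=mkl_aarch64_threadpool"
        return "--config=mkl_open_source_only"

    assert mkl_dnn_config("aarch64") == "--config=mkl_aarch64_threadpool"
    assert mkl_dnn_config("x86_64") == "--config=mkl_open_source_only"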

ci/run_pytest_tpu.sh

Lines changed: 5 additions & 4 deletions

@@ -40,20 +40,21 @@ source "ci/utilities/setup_build_environment.sh"
 strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
 "$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
 
-# Set up common test environment variables
+# Set up all common test environment variables
 export PY_COLORS=1
-export JAX_SKIP_SLOW_TESTS=true
 export JAX_PLATFORMS=tpu,cpu
+export JAX_SKIP_SLOW_TESTS=true
 # End of common test environment variable setup
 
 echo "Running TPU tests..."
+
 # Run single-accelerator tests in parallel
 JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
   --deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
-  --maxfail=20 -m "not multiaccelerator" tests examples
+  --maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py
 
 # Run Pallas printing tests, which need to run with I/O capturing disabled.
 TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
 
 # Run multi-accelerator across all chips
-"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
+"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py

docs/export/export.md

Lines changed: 38 additions & 13 deletions

@@ -240,7 +240,7 @@ present on the exporting machine:
 
 ```
 
-There is a safety check that will be raise an error when trying to compile
+There is a safety check that will raise an error when trying to compile
 an `Exported` object on a machine that does not have the accelerator
 for which the code was exported.
 
@@ -326,7 +326,7 @@ combinations of input shapes.
 
 See the {ref}`shape_poly` documentation.
 
-## Device polymorphic export
+## Device-polymorphic export
 
 An exported artifact may contain sharding annotations for inputs,
 outputs and for some intermediates, but these annotations do not refer
@@ -335,20 +335,28 @@ Instead, the sharding annotations refer to logical devices. This
 means that you can compile and run the exported artifacts on different
 physical devices that were used for exporting.
 
+The cleanest way to achieve a device-polymorphic export is to
+use shardings constructed with a `jax.sharding.AbstractMesh`,
+which contains only the mesh shape and axis names. But,
+you can achieve the same results if you use shardings
+constructed for a mesh with concrete devices, since the actual
+devices in the mesh are ignored for tracing and lowering:
+
 ```python
 >>> import jax
 >>> from jax import export
->>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import AbstractMesh, Mesh, NamedSharding
 >>> from jax.sharding import PartitionSpec as P
+>>>
+>>> # Use an AbstractMesh for exporting
+>>> export_mesh = AbstractMesh((("a", 4),))
 
->>> # Use the first 4 devices for exporting.
->>> export_devices = jax.local_devices()[:4]
->>> export_mesh = Mesh(export_devices, ("a",))
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(8 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
 >>> # `exp` knows for how many devices it was exported.
 >>> exp.nr_devices
@@ -359,8 +367,20 @@ physical devices that were used for exporting.
 >>> exp.in_shardings_hlo
 ({devices=[4]<=[4]},)
 
+>>> # You can also use a concrete set of devices for exporting
+>>> concrete_devices = jax.local_devices()[:4]
+>>> concrete_mesh = Mesh(concrete_devices, ("a",))
+>>> exp2 = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((32,), dtype=np.int32,
+...                          sharding=NamedSharding(concrete_mesh, P("a"))))
+
+>>> # You can expect the same results
+>>> assert exp.in_shardings_hlo == exp2.in_shardings_hlo
+
+>>> # When you call an Exported, you must use a concrete set of devices
+>>> arg = jnp.arange(8 * 4)
 >>> res1 = exp.call(jax.device_put(arg,
-...                                NamedSharding(export_mesh, P("a"))))
+...                                NamedSharding(concrete_mesh, P("a"))))
 
 >>> # Check out the first 2 shards of the result
 >>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
@@ -397,9 +417,11 @@ of devices than it was exported for:
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
+>>> arg = jnp.arange(4 * len(export_devices))
 >>> exp.call(arg) # doctest: +IGNORE_EXCEPTION_DETAIL
 Traceback (most recent call last):
 ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
@@ -420,13 +442,16 @@ artifacts using a new mesh constructed at the call site:
 >>> def f(x):
 ...   return x.T
 
->>> arg = jnp.arange(4 * len(export_devices))
->>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> exp = export.export(jax.jit(f))(
+...     jax.ShapeDtypeStruct((4 * len(export_devices),), dtype=np.int32,
+...                          sharding=NamedSharding(export_mesh, P("a"))))
 
 >>> # Prepare the mesh for calling `exp`.
 >>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
 
 >>> # Shard the arg according to what `exp` expects.
+>>> arg = jnp.arange(4 * len(export_devices))
 >>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
 >>> res = exp.call(sharded_arg)
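For reference, a consolidated, self-contained version of the AbstractMesh flow added in these docs (a sketch assuming jax 0.4.38 with at least 4 addressable devices available at call time; `calling_mesh` is illustrative):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import export
    from jax.sharding import AbstractMesh, Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    # Export against an abstract 4-device mesh: no physical devices needed yet.
    export_mesh = AbstractMesh((("a", 4),))
    exp = export.export(jax.jit(lambda x: x.T))(
        jax.ShapeDtypeStruct((32,), dtype=np.int32,
                             sharding=NamedSharding(export_mesh, P("a"))))

    # Calling requires a concrete mesh with the same number of devices.
    calling_mesh = Mesh(np.array(jax.local_devices()[:4]), ("a",))
    arg = jnp.arange(32, dtype=jnp.int32)
    sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
    res = exp.call(sharded_arg)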
