
Commit 633fcba

Merge branch 'main' of https://github.com/google/jax

2 parents: 97c8d5d + 67f24df


78 files changed: +2830 −1422 lines

.github/workflows/asan.yaml

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+name: CI - Address Sanitizer (nightly)
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+on:
+  schedule:
+    - cron: "0 12 * * *" # Daily at 12:00 UTC
+  workflow_dispatch: # allows triggering the workflow run manually
+  pull_request: # Automatically trigger on pull requests affecting this file
+    branches:
+      - main
+    paths:
+      - '**/workflows/asan.yml'
+
+jobs:
+  asan:
+    runs-on: linux-x86-n2-64
+    container:
+      image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
+    strategy:
+      fail-fast: false
+    defaults:
+      run:
+        shell: bash -l {0}
+    steps:
+      - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
+        with:
+          path: jax
+      - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
+        with:
+          repository: python/cpython
+          path: cpython
+          ref: v3.12.6
+      - name: Install clang 18
+        env:
+          DEBIAN_FRONTEND: noninteractive
+        run: |
+          apt update
+          apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
+            zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
+            libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
+            libffi-dev liblzma-dev
+      - name: Build CPython with ASAN enabled
+        env:
+          ASAN_OPTIONS: detect_leaks=0
+        run: |
+          cd cpython
+          mkdir ${GITHUB_WORKSPACE}/cpythonasan
+          CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc
+          make -j64
+          make install
+          ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv
+      - name: Install JAX test requirements
+        env:
+          ASAN_OPTIONS: detect_leaks=0
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          pip install -r build/test-requirements.txt
+      - name: Build and install JAX
+        env:
+          ASAN_OPTIONS: detect_leaks=0
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          python build/build.py \
+            --bazel_options=--color=yes \
+            --bazel_options=--copt=-fsanitize=address \
+            --clang_path=/usr/bin/clang-18
+          pip install dist/jaxlib-*.whl
+          pip install -e .
+      - name: Run tests
+        env:
+          ASAN_OPTIONS: detect_leaks=0
+          JAX_NUM_GENERATED_CASES: 1
+          JAX_ENABLE_X64: true
+          JAX_SKIP_SLOW_TESTS: true
+          PY_COLORS: 1
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
+          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
+          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
+          # The LD_PRELOAD works around https://github.com/google/sanitizers/issues/934#issuecomment-649516500
+          LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 python -m pytest -n auto --tb=short --maxfail=20 tests
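The test step configures JAX entirely through environment variables (`JAX_ENABLE_X64`, `JAX_NUM_GENERATED_CASES`, and so on). As a quick sanity check of that mechanism, here is a minimal sketch, assuming any working JAX install such as the venv built above:

```python
import os

# Mirror the workflow's `env:` block; JAX reads this at import time, so it
# must be set before `import jax`.
os.environ["JAX_ENABLE_X64"] = "true"

import jax.numpy as jnp

# With x64 enabled, float64 arrays keep their precision instead of being
# silently downcast to the float32 default.
x = jnp.ones(3, dtype="float64")
print(x.dtype)  # float64
```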

CHANGELOG.md

Lines changed: 8 additions & 1 deletion

@@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
 When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
 -->
 
-## jax 0.4.34
+## jax 0.4.35
+
+## jax 0.4.34 (October 4, 2024)
 
 * New Functionality
   * This release includes wheels for Python 3.13. Free-threading mode is not yet
@@ -46,6 +48,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
   * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
     `jax.errors.JaxRuntimeError` instead.
+  * The default behavior of {func}`jax.pure_callback` and
+    {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
+    the `vectorized` parameter to those functions. The `vmap_method` parameter
+    should be used instead for better defined behavior. See the discussion in
+    {jax-issue}`#23881` for more details.
 
 * Deletion:
   * `jax.xla_computation` is deleted. It's been 3 months since its deprecation
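For readers hit by the `vectorized` deprecation above, a hypothetical before/after sketch (the host function `np_square` is illustrative, not from the changelog):

```python
import jax
import jax.numpy as jnp
import numpy as np

def np_square(x):
    # Runs on the host via the callback mechanism.
    return np.square(x)

def square(x):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # Before (deprecated): jax.pure_callback(np_square, out, x, vectorized=True)
    # After: name the batching contract explicitly.
    return jax.pure_callback(np_square, out, x, vmap_method="broadcast_fullrank")

# Under vmap the callback now runs once on the stacked batch.
print(jax.vmap(square)(jnp.arange(6.0).reshape(2, 3)))
```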

README.md

Lines changed: 2 additions & 2 deletions

@@ -8,8 +8,8 @@
 
 # Transformable numerical computing at scale
 
-![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)
-![PyPI version](https://img.shields.io/pypi/v/jax)
+[![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg)](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml)
+[![PyPI version](https://img.shields.io/pypi/v/jax)](https://pypi.org/project/jax/)
 
 [**Quickstart**](#quickstart-colab-in-the-cloud)
 | [**Transformations**](#transformations)

docs/ffi.ipynb

Lines changed: 27 additions & 17 deletions

@@ -303,9 +303,9 @@
    "      # type (which corresponds to numpy's `float32` type), and it must be a\n",
    "      # static parameter (i.e. not a JAX array).\n",
    "      eps=np.float32(eps),\n",
-   "      # The `vectorized` parameter controls this function's behavior under `vmap`\n",
+   "      # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
    "      # as discussed below.\n",
-   "      vectorized=True,\n",
+   "      vmap_method=\"broadcast_fullrank\",\n",
    "  )\n",
    "\n",
    "\n",
@@ -325,7 +325,7 @@
    "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
    "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
    "\n",
-   "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
+   "The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
    "\n",
    "```{tip}\n",
    "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
@@ -336,19 +336,29 @@
    "(ffi-call-vmap)=\n",
    "### Batching with `vmap`\n",
    "\n",
-   "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n",
-   "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
-   "This default implementation is general purpose, but it doesn't parallelize very well.\n",
-   "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n",
+   "{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
+   "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
    "\n",
-   "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n",
+   "The simplest `vmap_method` is `\"sequential\"`.\n",
+   "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
+   "This implementation is general purpose, but it doesn't parallelize very well.\n",
+   "Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
+   "\n",
+   "In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
+   "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
    "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
    "\n",
    "```python\n",
    "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
    "```\n",
    "\n",
-   "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:"
+   "```{tip}\n",
+   "Note that things get a bit more complicated when we have multiple input arguments.\n",
+   "For simplicity, we will use `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
+   "The documentation for {func}`~jax.pure_callback` includes some examples of this.\n",
+   "```\n",
+   "\n",
+   "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
  ]
 },
 {
@@ -380,7 +390,7 @@
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-   "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
+   "Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
  ]
 },
 {
@@ -389,24 +399,24 @@
  "metadata": {},
  "outputs": [],
  "source": [
-   "def rms_norm_not_vectorized(x, eps=1e-5):\n",
+   "def rms_norm_sequential(x, eps=1e-5):\n",
    "  return jex.ffi.ffi_call(\n",
    "      \"rms_norm\",\n",
    "      jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
    "      x,\n",
    "      eps=np.float32(eps),\n",
-   "      vectorized=False,  # This is the default behavior\n",
+   "      vmap_method=\"sequential\",\n",
    "  )\n",
    "\n",
    "\n",
-   "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)"
+   "jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
  ]
 },
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
-   "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
+   "If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
  ]
 },
 {
@@ -454,7 +464,7 @@
    "      ),\n",
    "      x,\n",
    "      eps=np.float32(eps),\n",
-   "      vectorized=True,\n",
+   "      vmap_method=\"broadcast_fullrank\",\n",
    "  )\n",
    "  return y, (res, x)\n",
    "\n",
@@ -471,7 +481,7 @@
    "          res,\n",
    "          x,\n",
    "          ct,\n",
-   "          vectorized=True,\n",
+   "          vmap_method=\"broadcast_fullrank\",\n",
    "      ),\n",
    "  )\n",
    "\n",
@@ -561,7 +571,7 @@
    "        out_type,\n",
    "        x,\n",
    "        eps=np.float32(eps),\n",
-   "        vectorized=True,\n",
+   "        vmap_method=\"broadcast_fullrank\",\n",
    "    )\n",
    "\n",
    "  return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",

docs/ffi.md

Lines changed: 27 additions & 17 deletions

@@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5):
       # type (which corresponds to numpy's `float32` type), and it must be a
       # static parameter (i.e. not a JAX array).
       eps=np.float32(eps),
-      # The `vectorized` parameter controls this function's behavior under `vmap`
+      # The `vmap_method` parameter controls this function's behavior under `vmap`
       # as discussed below.
-      vectorized=True,
+      vmap_method="broadcast_fullrank",
   )
 
 
@@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal
 Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
 Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
 
-The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
+The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
 
 ```{tip}
 If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
@@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so
 (ffi-call-vmap)=
 ### Batching with `vmap`
 
-All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.
-By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
-This default implementation is general purpose, but it doesn't parallelize very well.
-But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.
+{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
+The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
 
-The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.
+The simplest `vmap_method` is `"sequential"`.
+In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
+This implementation is general purpose, but it doesn't parallelize very well.
+Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
+
+In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
+The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
 Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
 
 ```python
 ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
 ```
 
-Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:
+```{tip}
+Note that things get a bit more complicated when we have multiple input arguments.
+For simplicity, we will use `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
+The documentation for {func}`~jax.pure_callback` includes some examples of this.
+```
+
+Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
 
 ```{code-cell} ipython3
 np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
@@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms
 jax.make_jaxpr(jax.vmap(rms_norm))(x)
 ```
 
-If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
+Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
 
 ```{code-cell} ipython3
-def rms_norm_not_vectorized(x, eps=1e-5):
+def rms_norm_sequential(x, eps=1e-5):
   return jex.ffi.ffi_call(
       "rms_norm",
       jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
       eps=np.float32(eps),
-      vectorized=False,  # This is the default behavior
+      vmap_method="sequential",
   )
 
 
-jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
+jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
 ```
 
-If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
+If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
 
 +++
 
@@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5):
       ),
       x,
       eps=np.float32(eps),
-      vectorized=True,
+      vmap_method="broadcast_fullrank",
   )
   return y, (res, x)
 
@@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct):
           res,
           x,
           ct,
-          vectorized=True,
+          vmap_method="broadcast_fullrank",
       ),
   )
 
@@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
         out_type,
         x,
         eps=np.float32(eps),
-        vectorized=True,
+        vmap_method="broadcast_fullrank",
     )
 
   return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
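The stacking identity `ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])` quoted in this diff can be verified without a foreign function. A runnable sketch using `jax.pure_callback` (same `vmap_method` semantics per the docs above), assuming a host function that normalizes along its last axis:

```python
import jax
import jax.numpy as jnp
import numpy as np

def rms_norm_host(x):
    # Normalizing along the last axis means a leading batch dimension is
    # handled for free -- the property the "broadcast" methods rely on.
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True))

def rms_norm(x):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(rms_norm_host, out, x,
                             vmap_method="broadcast_fullrank")

x = jnp.linspace(-1.0, 1.0, 32).reshape(4, 8)
np.testing.assert_allclose(
    jax.vmap(rms_norm)(x),                    # one callback on the whole batch
    jnp.stack([rms_norm(row) for row in x]),  # stacked per-element calls
    rtol=1e-5,
)
```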

docs/jax.lax.rst

Lines changed: 1 addition & 0 deletions

@@ -255,4 +255,5 @@ Argument classes
 .. autoclass:: Precision
 .. autoclass:: PrecisionLike
 .. autoclass:: RoundingMethod
+   :members:
 .. autoclass:: ScatterDimensionNumbers
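With `:members:` added, the `RoundingMethod` enum values appear in the rendered docs. A small usage sketch of the two members, which select tie-breaking behavior in `jax.lax.round`:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5])
# Ties round away from zero: 0.5 -> 1, 1.5 -> 2, 2.5 -> 3.
print(lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO))
# Ties round to the nearest even value: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2.
print(lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN))
```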

docs/jax.numpy.rst

Lines changed: 1 addition & 0 deletions

@@ -376,6 +376,7 @@ namespace; they are listed below.
     size
     sort
     sort_complex
+    spacing
     split
     sqrt
     square
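A brief illustration of the newly listed `spacing`, which mirrors `numpy.spacing`: the distance from a value to the next representable number of the same dtype:

```python
import jax.numpy as jnp

print(jnp.spacing(jnp.float32(1.0)))  # float32 machine epsilon at 1.0, ~1.19e-07
print(jnp.spacing(jnp.float32(0.0)))  # the smallest positive subnormal float32
```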

docs/notebooks/shard_map.ipynb

Lines changed: 2 additions & 0 deletions

@@ -510,6 +510,8 @@
    "the corresponding `PartitionSpec` `spec` as roughly\n",
    "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n",
    "\n",
+   "(shard_map_collectives_tutorial)=\n",
+   "\n",
    "## Collectives tutorial\n",
    "\n",
    "A `shard_map` need not be a pure map: function applications can communicate\n",

docs/notebooks/shard_map.md

Lines changed: 2 additions & 0 deletions

@@ -357,6 +357,8 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and
 the corresponding `PartitionSpec` `spec` as roughly
 `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.
 
+(shard_map_collectives_tutorial)=
+
 ## Collectives tutorial
 
 A `shard_map` need not be a pure map: function applications can communicate
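The shard-shape rule quoted in the context lines above can be evaluated with plain Python; `mesh_shape` and `spec` below are illustrative stand-ins for `mesh.shape` and a `PartitionSpec`:

```python
mesh_shape = {"x": 4, "y": 2}  # stand-in for mesh.shape
shape = (8, 6)                 # global shape of the argument
spec = ("x", None)             # stand-in for PartitionSpec("x", None)

# Each axis is divided by the size of the mesh axis it is sharded over.
local_shape = tuple(sz // (1 if n is None else mesh_shape[n])
                    for sz, n in zip(shape, spec))
print(local_shape)  # (2, 6): axis 0 split 4 ways, axis 1 unsharded
```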

docs/pallas/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,8 @@ Remember to align the itemized text with the first line of an item within a list
 * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
   to be scalars. The restrictions on the arguments are backend-specific:
   Non-scalar arguments are currently only supported on GPU, when using Triton.
+* {class}`jax.experimental.pallas.BlockSpec` no longer supports the previously
+  deprecated argument order, where `index_map` comes before `block_shape`.
 
 * Deprecations
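To make the `BlockSpec` change concrete, a sketch of the accepted argument order (the block shape and index map here are arbitrary examples):

```python
from jax.experimental import pallas as pl

# Supported: block_shape first, then index_map.
spec = pl.BlockSpec((128, 128), lambda i, j: (i, j))

# Previously deprecated, now removed: index_map before block_shape.
# spec = pl.BlockSpec(lambda i, j: (i, j), (128, 128))
```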
