Commit e559d27

Merge branch 'jax-ml:main' into main
2 parents 5953bfc + a88486c

File tree

98 files changed (+2787, -5221 lines)


.bazelrc

Lines changed: 4 additions & 0 deletions
@@ -257,6 +257,10 @@ build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm
 
 # Mac Arm64 CI configs
 build:ci_darwin_arm64 --macos_minimum_os=11.0
+# Clang 19 requires `-Wno-error=c23-extensions` but this flag is not supported
+# on Apple Clang in XCode 16.0 so we suppress unknown warning option errors
+# on Mac CI builds.
+build:ci_darwin_arm64 --copt=-Wno-unknown-warning-option
 build:ci_darwin_arm64 --config=macos_cache_push
 build:ci_darwin_arm64 --verbose_failures=true
 build:ci_darwin_arm64 --color=yes

.github/ISSUE_TEMPLATE/bug-report.yml

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ body:
 
       [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
 
-      [Raw report]: http://github.com/jax-ml/jax/issues/new
+      [Raw report]: https://github.com/jax-ml/jax/issues/new?template=none
   - type: textarea
     attributes:
       label: Description

.readthedocs.yml

Lines changed: 2 additions & 4 deletions
@@ -13,10 +13,8 @@ build:
     post_checkout:
       # Skip building PRs unless tagged with the "documentation" label.
       - |
-        if [ "$READTHEDOCS_VERSION_TYPE" = "external" ] && (curl -s "https://api.github.com/repos/jax-ml/jax/issues/$READTHEDOCS_VERSION/labels" | grep -vq "https://api.github.com/repos/jax-ml/jax/labels/documentation")
-        then
-          exit 183;
-        fi
+        [ "${READTHEDOCS_VERSION_TYPE}" != "external" ] && echo "Building latest" && exit 0
+        (curl -sL https://api.github.com/repos/jax-ml/jax/issues/${READTHEDOCS_VERSION}/labels | grep -q "https://api.github.com/repos/jax-ml/jax/labels/documentation") && echo "Building PR with label" || exit 183
 
 # Build documentation in the docs/ directory with Sphinx
 sphinx:

CHANGELOG.md

Lines changed: 24 additions & 1 deletion
@@ -24,11 +24,25 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   which was added temporarily in v0.4.36 to allow users to opt out of the
   new "stackless" tracing machinery.
 * Removed the `config.jax_eager_pmap` config option.
+* Disallow the calling of `lower` and `trace` AOT APIs on the result
+  of `jax.jit` if there have been subsequent wrappers applied.
+  Previously this worked, but silently ignored the wrappers.
+  The workaround is to apply `jax.jit` last among the wrappers,
+  and similarly for `jax.pmap`.
+  See {jax-issue}`#27873`.
+* The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]`
+  instead.
 
 * Changes
   * The minimum CuDNN version is v9.8.
   * JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
     supported.
+  * JAX package extras are now updated to use dash instead of underscore to
+    align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
+    to install JAX, run `pip install jax[cuda12-local]` instead.
+  * {func}`jax.jit` now requires `fun` to be passed by position, and additional
+    arguments to be passed by keyword. Doing otherwise will result in a
+    DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
 
 * Deprecations
 

@@ -45,10 +59,17 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   * The deprecated use of {func}`jax.ffi.ffi_call` with inline arguments is no
     longer supported. {func}`~jax.ffi.ffi_call` now unconditionally returns a
     callable.
+  * `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX
+    `Array` directly to the `from_dlpack` function of another framework. If you
+    need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an
+    array.
+  * `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and
+    `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
   * Several previously-deprecated APIs have been removed, including:
     * From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`,
       `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`,
-      `ops`, `register_custom_call_target`, `shape_from_pyval`.
+      `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`,
+      `XlaComputation`.
     * From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
     * From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`,
       `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and

@@ -62,6 +83,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
       `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`,
       `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most
       have no public replacement, though a few are available at {mod}`jax.extend.core`.
+    * The `vectorized` argument to {func}`~jax.pure_callback` and
+      {func}`~jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
 
 ## jax 0.5.3 (Mar 19, 2025)
 
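
A few hedged sketches of the workflow changes called out in the changelog above follow. First, the `lower`/`trace` entry recommends applying `jax.jit` last among any wrappers so that the AOT APIs remain reachable on the object you hold. A minimal sketch of that ordering (the `jax.grad` wrapper and the `loss` function are illustrative assumptions, not taken from the commit):

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

# Apply jax.jit last, so the object you end up holding is the jitted callable
# and the AOT entry points (.lower / .trace) are available on it.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.ones(4)
x = jnp.arange(4.0)
lowered = grad_loss.lower(w, x)   # AOT lowering on the jitted callable
compiled = lowered.compile()
print(compiled(w, x))
```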
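
For the `jax.jit` calling-convention change, a small sketch of the now-required style, with the function passed by position and everything else by keyword:

```python
import jax
import jax.numpy as jnp

def power(x, n):
    return x ** n

# Supported: the function is positional, other arguments are keywords.
power_jit = jax.jit(power, static_argnums=1)

# Deprecated in v0.6.x and an error later: passing the function by keyword,
# e.g. jax.jit(fun=power, static_argnums=1).

print(power_jit(jnp.arange(3.0), 2))   # [0. 1. 4.]
```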
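
For the `jax.dlpack.to_dlpack` deprecation, a sketch of the two alternatives the entry suggests; this assumes a CPU-backed array so that NumPy can import it:

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)

# Preferred: hand the JAX array straight to the consumer's from_dlpack.
x_np = np.from_dlpack(x)  # assumes x lives on a device NumPy can read (CPU)

# If the DLPack capsule itself is needed, the array's __dlpack__ method
# provides it, instead of the deprecated jax.dlpack.to_dlpack.
capsule = x.__dlpack__()
print(x_np)
```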
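
For the removed `vectorized` argument, a sketch of the `vmap_method` replacement on `jax.pure_callback`; the `"sequential"` choice and the `host_sin` callback are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs outside the traced computation, on the host, via NumPy.
    return np.sin(x)

def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # vmap_method replaces the old vectorized=... flag; "sequential" maps the
    # callback over the batch one element at a time.
    return jax.pure_callback(host_sin, out_type, x, vmap_method="sequential")

x = jnp.linspace(0.0, 1.0, 4)
print(jax.vmap(f)(x))
```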

docs/gpu_performance_tips.md

Lines changed: 221 additions & 14 deletions
@@ -243,20 +243,6 @@ Run the real workflow, if you found these loggings in the running log, it means
 
 By adjusting this factor, users can fine-tune the trade-off between memory efficiency
 and performance optimizations.
-* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
-  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
-  i-th layer computation. It also enables overlapping (i+1)-th layer
-  weight `Reduce`/`ReduceScatter` with i-th layer's computation. The default
-  value is False. **There are some bugs when this flag is turned on.**
-* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
-  performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a
-  nonzero threshold decomposes `CollectivePermute`s into
-  `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that
-  computation can be performed between each corresponding
-  `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the
-  threshold is 0 and there is no decomposition. Setting it to threshold > 0 such
-  as `--xla_gpu_collective_permute_decomposer_threshold=1024` can enable this
-  feature.
 * **--xla_gpu_all_gather_combine_threshold_bytes**
   **--xla_gpu_reduce_scatter_combine_threshold_bytes**
   **--xla_gpu_all_reduce_combine_threshold_bytes**
@@ -268,6 +254,227 @@ Run the real workflow, if you found these loggings in the running log, it means
 combine at least a Transformer Layer's weight `AllGather`/`ReduceScatter`. By
 default, the `combine_threshold_bytes` is set to 256.
 
+### Pipeline Parallelism on GPU
+
+XLA implements SPMD-based pipeline parallelism optimizations. This is a scaling technique
+where the forward and backward pass are split into multiple pipeline stages.
+Each device (or device group) processes the result of the previous
+pipeline stage (or the pipeline input) and sends its partial result to the next
+stage until the end of the pipeline is reached. This optimization works best
+when the latency of the computation is larger than the communication latency. At compile
+time, the operations will be rearranged to overlap communication with
+computation.
+
+For an optimized schedule, we recommend these XLA flags:
+```
+--xla_gpu_enable_latency_hiding_scheduler=true
+--xla_gpu_enable_command_buffer=''
+--xla_disable_hlo_passes=collective-permute-motion
+--xla_gpu_experimental_pipeline_parallelism_opt_level=PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE
+```
+
+The following JAX example demonstrates a pattern where communication operations
+are scheduled to overlap with computations. In this example we will illustrate
+how to set up an optimized pipeline parallelism schedule using 4 GPUs that
+form a communication ring (device 0 -> device 1 -> device 2 -> device 3 ->
+device 0). We refer to the pattern `0 -> 1 -> 2 -> 3` as the forward edge, and
+`3 -> 0` as the back edge.
+
+```
+# Imports and setup
+import functools
+import jax
+from jax import sharding
+from jax.experimental import mesh_utils
+import jax.numpy as jnp
+import jax.random
+
+NUM_DEVICES = 4
+NUM_MICROBATCHES = 5
+NUM_CIRC_REPEATS = 2
+CONTRACTING_DIM_SIZE = 4096
+NON_CONTRACTING_DIM_SIZE = 8192
+COMPUTE_INTENSITY = 32
+
+# Creates a collective permute for the "forward edge".
+# 0->1, 1->2, ... (N-2)->(N-1)
+def shift_right(arr):
+  padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1)
+  # Use lax.slice to guarantee the gradient is a pad.
+  return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape)
+
+
+# Creates a collective permute for the "back edge".
+# (N-1)->0
+def cycle_back(arr):
+  padding = [[0, NUM_DEVICES - 1]] + [[0, 0]] * (arr.ndim - 1)
+  return jax.lax.slice(
+      jnp.pad(arr, padding),
+      [NUM_DEVICES - 1] + [0] * (arr.ndim - 1),
+      (NUM_DEVICES - 1 + arr.shape[0],) + arr.shape[1:],
+  )
+
+
+def select_on_first_device(then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_first_device = jax.lax.broadcasted_iota("int32", then_value.shape, 0) == 0
+  return jnp.where(is_first_device, then_value, else_value)
+
+
+def select_on_last_device(then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_last_device = (
+      jax.lax.broadcasted_iota("int32", then_value.shape, 0) == NUM_DEVICES - 1
+  )
+  return jnp.where(is_last_device, then_value, else_value)
+
+
+def select_on_first_cycle(i, then_value, else_value):
+  assert then_value.shape == else_value.shape
+  is_first_cycle = i < NUM_MICROBATCHES
+  return jnp.where(is_first_cycle, then_value, else_value)
+
+
+def while_body(carry, i):
+  """Body of the pipeline while loop."""
+  weights, input_buffer, output_buffer, fwd_edge_data, bwd_edge_data = carry
+
+  # Read input data from input buffer.
+  input_data = jax.lax.dynamic_slice(
+      input_buffer,
+      (0, (i + 0) % NUM_MICROBATCHES, 0, 0),
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
+  )
+
+  # Collective permute on the "forward edge" shifts data to the next stage.
+  fwd_edge_data = shift_right(fwd_edge_data)
+
+  # Select compute argument based on device and pipeline cycle.
+  compute_argument = select_on_first_device(
+      select_on_first_cycle(i, input_data, bwd_edge_data),
+      fwd_edge_data,
+  ).reshape((NUM_DEVICES, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE))
+
+  # A few matmuls to simulate compute.
+  tmp = compute_argument
+  for _ in range(COMPUTE_INTENSITY):
+    tmp = jax.lax.dot_general(weights, tmp, (((2,), (1,)), ((0,), (0,))))
+  compute_result = tmp.reshape(
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
+  )
+
+  # Read data from buffer to pass it to the first device of the pipeline on the
+  # "back edge".
+  bwd_edge_data = jax.lax.dynamic_slice(
+      output_buffer,
+      (0, (1 + i) % NUM_MICROBATCHES, 0, 0),
+      (NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE),
+  )
+
+  # Collective permute on the "back edge" passes data to the first device.
+  bwd_edge_data = cycle_back(bwd_edge_data)
+
+  # Update output buffer. We do this after reading from it to avoid the data
+  # dependency.
+  output_buffer = jax.lax.dynamic_update_slice(
+      output_buffer,
+      compute_result,
+      (0, (2 + i) % NUM_MICROBATCHES, 0, 0),
+  )
+
+  fwd_edge_data = compute_result
+  carry = (
+      weights,
+      input_buffer,
+      output_buffer,
+      fwd_edge_data,
+      bwd_edge_data,
+  )
+  return carry, i
+
+
+@functools.partial(jax.jit, static_argnames=["mesh"])
+def entry_computation(weights, input_buffer, mesh):
+
+  # Init output buffer.
+  output_buffer = jnp.zeros_like(input_buffer)
+
+  # Init dummy data for forward and backward edge passed through the while loop.
+  dummy_data = jnp.zeros(
+      shape=(NUM_DEVICES, 1, CONTRACTING_DIM_SIZE, NON_CONTRACTING_DIM_SIZE)
+  ).astype(jnp.float32)
+  dummy_data = jax.device_put(
+      dummy_data,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Start pipeline.
+  carry = weights, input_buffer, output_buffer, dummy_data, dummy_data
+  num_iterations = NUM_CIRC_REPEATS * NUM_MICROBATCHES + NUM_DEVICES - 1
+  carry, _ = jax.lax.scan(while_body, carry, xs=jnp.arange(num_iterations))
+  _, _, output_buffer, _, _ = carry
+
+  return output_buffer
+
+
+def main(_):
+
+  # Expect constant number of devices.
+  assert NUM_DEVICES == jax.local_device_count()
+
+  # Create mesh.
+  mesh = sharding.Mesh(
+      mesh_utils.create_device_mesh([NUM_DEVICES]),
+      axis_names=["the_one_and_only_axis"],
+  )
+
+  # Init weights.
+  weights = 1.0 / CONTRACTING_DIM_SIZE
+  weights = jax.lax.broadcast_in_dim(
+      weights,
+      shape=(NUM_DEVICES, CONTRACTING_DIM_SIZE, CONTRACTING_DIM_SIZE),
+      broadcast_dimensions=(),
+  )
+  weights = jax.device_put(
+      weights,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Init random input and replicate it across all devices.
+  random_key = jax.random.key(0)
+  input_buffer = jax.random.uniform(
+      random_key,
+      shape=(
+          NUM_MICROBATCHES,
+          CONTRACTING_DIM_SIZE,
+          NON_CONTRACTING_DIM_SIZE,
+      ),
+  )
+  input_buffer = jax.lax.broadcast_in_dim(
+      input_buffer,
+      shape=(
+          NUM_DEVICES,
+          NUM_MICROBATCHES,
+          CONTRACTING_DIM_SIZE,
+          NON_CONTRACTING_DIM_SIZE,
+      ),
+      broadcast_dimensions=[1, 2, 3],
+  )
+  input_buffer = jax.device_put(
+      input_buffer,
+      sharding.NamedSharding(
+          mesh, sharding.PartitionSpec("the_one_and_only_axis")
+      ),
+  )
+
+  # Run computation.
+  output_buffer = entry_computation(weights, input_buffer, mesh)
+  print(f"output_buffer = \n{output_buffer}")
+```
 ## NCCL flags
 
 These Nvidia NCCL flag values may be useful for single-host multi-device
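
The pipeline-parallelism flags recommended in the section above are typically handed to XLA through the `XLA_FLAGS` environment variable before JAX is imported. A minimal sketch (flag values copied from the list above; setting them in the shell instead works the same way):

```python
import os

# Set before importing jax so XLA picks the flags up at initialization.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_command_buffer=",  # empty value, as in the list above
    "--xla_disable_hlo_passes=collective-permute-motion",
    "--xla_gpu_experimental_pipeline_parallelism_opt_level="
    "PIPELINE_PARALLELISM_OPT_LEVEL_ENABLE",
])

import jax
print(jax.devices())
```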

docs/installation.md

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ pip install --upgrade pip
 
 # Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
 # Note: wheels only available on linux.
-pip install --upgrade "jax[cuda12_local]"
+pip install --upgrade "jax[cuda12-local]"
 ```
 
 **These `pip` installations do not work with Windows, and may fail silently; refer to the table

@@ -296,7 +296,7 @@ pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.co
 - NVIDIA GPU (CUDA 12):
 
   ```bash
-  pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+  pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
   ```
 
 - NVIDIA GPU (CUDA 12) legacy:
