@@ -10,9 +10,54 @@ Remember to align the itemized text with the first line of an item within a list
 When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
 -->
 
-## jax 0.4.32
+## jax 0.4.34
+
+* Deletions:
+  * `jax.xla_computation` is deleted. It has been 3 months since its deprecation
+    in the 0.4.30 JAX release.
+    Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
+    * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
+      `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
+    * You can also use the `.out_info` property of `jax.stages.Lowered` to get
+      output information (such as tree structure, shape, and dtype).
+    * For cross-backend lowering, you can replace
+      `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
+      `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
+  * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
+    The argument was only used by `xmap`, which was removed in 0.4.31.
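The suggested migration above can be sketched end to end (a minimal sketch: `fn` and its argument are hypothetical stand-ins, and any available default backend is assumed):

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) * 2.0

# Old: jax.xla_computation(fn)(x) — now expressed via the AOT APIs.
lowered = jax.jit(fn).lower(jnp.ones((3,), jnp.float32))
hlo = lowered.compiler_ir('hlo')   # HLO computation, as before
info = lowered.out_info            # output tree structure, shape, and dtype
```

Here `lowered.compiler_ir('hlo')` returns an object whose `as_hlo_text()` gives the HLO module text that `jax.xla_computation` used to produce.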
+
+
+## jax 0.4.33 (September 16, 2024)
+
+This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
+release.
+
+A TPU-only data corruption bug was found in the version of libtpu pinned by
+JAX 0.4.32, which manifested only if multiple TPU slices were present in the
+same job, for example, if training on multiple v5e slices.
+This release fixes that issue by pinning a fixed version of `libtpu`.
+
+This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
+## jax 0.4.32 (September 11, 2024)
+
+Note: This release was yanked from PyPI because of a data corruption bug on TPU.
+See the 0.4.33 release notes for more details.
+
+* New Functionality
+  * Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
+    to support the use of the new {ref}`ffi-tutorial` to interface with custom
+    C++ and CUDA code from JAX.
 
 * Changes
+  * `jax_pmap_no_rank_reduction` flag is set to `True` by default.
+    * `array[0]` on a pmap result now introduces a reshape (use `array[0:1]`
+      instead).
+    * The per-shard shape (accessible via `jax_array.addressable_shards` or
+      `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code
+      that directly accesses shards accordingly. The rank of the per-shard shape
+      now matches that of the global shape, which is the same behavior as jit.
+      This avoids costly reshapes when passing results from pmap into jit.
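The slicing advice above can be illustrated with a small sketch (assumes at least one local device; the lambda is a hypothetical stand-in):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
out = jax.pmap(lambda x: x * 2)(jnp.arange(n, dtype=jnp.float32))

# out[0] now rank-reduces (introducing a reshape); slice to keep the axis.
first = out[0:1]                 # keeps the leading axis, no reshape
shard = out.addressable_data(0)  # per-shard shape now has a leading 1
```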
   * `jax_enable_memories` flag is set to `True` by default.
   * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
     See {ref}`python-array-api` for more information.
@@ -60,7 +105,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   The argument to {func}`jax.dlpack.from_dlpack` should be an array from
   another framework that implements the ``__dlpack__`` protocol.
 
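A minimal sketch of the call pattern described above, using a NumPy array (which implements `__dlpack__`) as the source framework:

```python
import numpy as np
import jax

x = np.arange(4, dtype=np.float32)
y = jax.dlpack.from_dlpack(x)  # pass the array itself, not a DLPack capsule
```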
-## jaxlib 0.4.32
+## jaxlib 0.4.32 (September 11, 2024)
+
+Note: This release was yanked from PyPI because of a data corruption bug on TPU.
+See the 0.4.33 release notes for more details.
 
 * Breaking changes
   * Hermetic CUDA support is added.