Skip to content

Commit 0bfc868

Browse files
committed
Merge branch 'main' of https://github.com/google-ml-infra/jax-fork into srnitin/task-jax-ci-rework
2 parents faa98f3 + fd7a21e commit 0bfc868

File tree

123 files changed

+5098
-2103
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

123 files changed

+5098
-2103
lines changed

.github/workflows/ci-build.yaml

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ permissions:
2020
contents: read # to fetch code
2121
actions: write # to cancel previous workflows
2222

23+
concurrency:
24+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
25+
cancel-in-progress: true
26+
2327
jobs:
2428
lint_and_typecheck:
2529
runs-on: ubuntu-latest
2630
timeout-minutes: 5
2731
steps:
28-
- name: Cancel previous
29-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
30-
with:
31-
access_token: ${{ github.token }}
32-
if: ${{github.ref != 'refs/heads/main'}}
3332
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
3433
- name: Set up Python 3.11
3534
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -58,11 +57,6 @@ jobs:
5857
prng-upgrade: 0
5958
num_generated_cases: 1
6059
steps:
61-
- name: Cancel previous
62-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
63-
with:
64-
access_token: ${{ github.token }}
65-
if: ${{github.ref != 'refs/heads/main'}}
6660
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
6761
- name: Set up Python ${{ matrix.python-version }}
6862
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -110,11 +104,6 @@ jobs:
110104
matrix:
111105
python-version: ['3.10']
112106
steps:
113-
- name: Cancel previous
114-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
115-
with:
116-
access_token: ${{ github.token }}
117-
if: ${{github.ref != 'refs/heads/main'}}
118107
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
119108
- name: Set up Python ${{ matrix.python-version }}
120109
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -152,11 +141,6 @@ jobs:
152141
matrix:
153142
python-version: ['3.10']
154143
steps:
155-
- name: Cancel previous
156-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
157-
with:
158-
access_token: ${{ github.token }}
159-
if: ${{github.ref != 'refs/heads/main'}}
160144
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
161145
- name: Set up Python ${{ matrix.python-version }}
162146
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
@@ -193,11 +177,6 @@ jobs:
193177
enable-x64: 0
194178
num_generated_cases: 10
195179
steps:
196-
- name: Cancel previous
197-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
198-
with:
199-
access_token: ${{ github.token }}
200-
if: ${{github.ref != 'refs/heads/main'}}
201180
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
202181
- name: Set up Python ${{ matrix.python-version }}
203182
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5

.github/workflows/jax-array-api.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ on:
99
- '**workflows/jax-array-api.yml'
1010
- '**experimental/array_api/**'
1111

12+
concurrency:
13+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
14+
cancel-in-progress: true
15+
1216
jobs:
1317
build:
1418

@@ -25,7 +29,7 @@ jobs:
2529
with:
2630
repository: data-apis/array-api-tests
2731
# TODO(jakevdp) update this to a stable release/tag when available.
28-
ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08
32+
ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09
2933
submodules: 'true'
3034
path: 'array-api-tests'
3135
- name: Set up Python ${{ matrix.python-version }}

.github/workflows/metal_plugin_ci.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ on:
1111
paths:
1212
- '**workflows/metal_plugin_ci.yml'
1313

14+
concurrency:
15+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
16+
cancel-in-progress: true
17+
1418
jobs:
1519
jax-metal-plugin-test:
1620

.github/workflows/wheel_win_x64.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ name: Wheel build - Windows CPU x86_64
22
on:
33
workflow_dispatch: # allows triggering the workflow run manually
44

5+
concurrency:
6+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
7+
cancel-in-progress: true
8+
59
env:
610
DISTUTILS_USE_SDK: 1
711
MSSdk: 1
@@ -18,11 +22,6 @@ jobs:
1822
runs-on: ${{ matrix.os }}
1923

2024
steps:
21-
- name: Cancel Previous Runs
22-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
23-
with:
24-
access_token: ${{ github.token }}
25-
2625
- name: Install LLVM/Clang
2726
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
2827

@@ -58,7 +57,7 @@ jobs:
5857
JAX_SKIP_SLOW_TESTS: true
5958
PY_COLORS: 1
6059
run: |
60+
python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
6161
python -m pip install -e ${{ github.workspace }}
62-
python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib
6362
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
6463
pytest -n auto --tb=short tests examples

.github/workflows/windows_ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ on:
66
pull_request:
77
types: [ labeled ] # allow force-windows-run label
88

9+
concurrency:
10+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
11+
cancel-in-progress: true
12+
913
env:
1014
DISTUTILS_USE_SDK: 1
1115
MSSdk: 1
@@ -23,10 +27,6 @@ jobs:
2327
runs-on: ${{ matrix.os }}
2428

2529
steps:
26-
- name: Cancel Previous Runs
27-
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
28-
with:
29-
access_token: ${{ github.token }}
3030

3131
- name: Install LLVM/Clang
3232
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade

CHANGELOG.md

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,54 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.32
13+
## jax 0.4.34
14+
15+
* Deletion:
16+
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
17+
in 0.4.30 JAX release.
18+
Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
19+
* `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
20+
`jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
21+
* You can also use `.out_info` property of `jax.stages.Lowered` to get the
22+
output information (like tree structure, shape and dtype).
23+
* For cross-backend lowering, you can replace
24+
`jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
25+
`jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
26+
* {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument.
27+
The argument was only used by `xmap` which was removed in 0.4.31.
28+
29+
30+
## jax 0.4.33 (September 16, 2024)
31+
32+
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that
33+
release.
34+
35+
A TPU-only data corruption bug was found in the version of libtpu pinned by
36+
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
37+
same job, for example, if training on multiple v5e slices.
38+
This release fixes that issue by pinning a fixed version of `libtpu`.
39+
40+
This release fixes an inaccurate result for F64 tanh on CPU (#23590).
41+
42+
## jax 0.4.32 (September 11, 2024)
43+
44+
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
45+
See the 0.4.33 release notes for more details.
46+
47+
* New Functionality
48+
* Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering`
49+
to support the use of the new {ref}`ffi-tutorial` to interface with custom
50+
C++ and CUDA code from JAX.
1451

1552
* Changes
53+
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
54+
* array[0] on a pmap result now introduces a reshape (use array[0:1]
55+
instead).
56+
* The per-shard shape (accessable via jax_array.addressable_shards or
57+
jax_array.addressable_data(0)) now has a leading (1, ...). Update code
58+
that directly accesses shards accordingly. The rank of the per-shard-shape
59+
now matches that of the global shape which is the same behavior as jit.
60+
This avoids costly reshapes when passing results from pmap into jit.
1661
* `jax_enable_memories` flag is set to `True` by default.
1762
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
1863
See {ref}`python-array-api` for more information.
@@ -60,7 +105,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
60105
The argument to {func}`jax.dlpack.from_dlpack` should be an array from
61106
another framework that implements the ``__dlpack__`` protocol.
62107

63-
## jaxlib 0.4.32
108+
## jaxlib 0.4.32 (September 11, 2024)
109+
110+
Note: This release was yanked from PyPi because of a data corruption bug on TPU.
111+
See the 0.4.33 release notes for more details.
64112

65113
* Breaking changes
66114
* Hermetic CUDA support is added.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" alt="logo"></img>
77
</div>
88

9-
# Scalable, transformable, high-performance machine learning
9+
# Transformable numerical computing at scale
1010

1111
![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg)
1212
![PyPI version](https://img.shields.io/pypi/v/jax)

WORKSPACE

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ install_deps()
3737
load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter")
3838
custom_python_interpreter(
3939
name = "python_dev",
40-
urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"],
41-
strip_prefix = "Python-{version}",
42-
version = "3.13.0a6",
40+
urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"],
41+
strip_prefix = "Python-{version_variant}",
42+
version = "3.13.0",
43+
version_variant = "3.13.0rc2",
4344
)
4445

4546
load("@xla//:workspace4.bzl", "xla_workspace4")

build/requirements.in

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ matplotlib; python_version>="3.11"
1111
#
1212
# build deps
1313
#
14-
numpy~=2.0.0
14+
numpy~=2.0.0; python_version<="3.12"
15+
numpy~=2.1.0; python_version>="3.13"
1516

1617
#
1718
# runtime deps
1819
#
19-
scipy~=1.13.1
20+
scipy>=1.13.1
2021

2122
ml_dtypes>=0.4.0
2223
opt_einsum

0 commit comments

Comments
 (0)