Skip to content

Commit f90bbfa

Browse files
committed
Merge branch 'main' of https://github.com/google/jax
2 parents b8372b7 + 4688da3 commit f90bbfa

File tree

257 files changed

+10816
-3856
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

257 files changed

+10816
-3856
lines changed

.bazelrc

Lines changed: 240 additions & 183 deletions
Large diffs are not rendered by default.

.github/workflows/ci-build.yaml

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,31 @@ jobs:
3737
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
3838

3939
build:
40-
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
41-
runs-on: ${{ matrix.os }}
40+
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
41+
runs-on: linux-x86-n2-32
42+
container:
43+
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
4244
timeout-minutes: 60
4345
strategy:
4446
matrix:
4547
# Test the oldest and newest supported Python versions here.
4648
include:
4749
- name-prefix: "with 3.10"
4850
python-version: "3.10"
49-
os: ubuntu-20.04-16core
5051
enable-x64: 1
5152
prng-upgrade: 1
5253
num_generated_cases: 1
53-
- name-prefix: "with 3.12"
54-
python-version: "3.12"
55-
os: ubuntu-20.04-16core
54+
- name-prefix: "with 3.13"
55+
python-version: "3.13"
5656
enable-x64: 0
5757
prng-upgrade: 0
5858
num_generated_cases: 1
5959
steps:
6060
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
61+
- name: Image Setup
62+
run: |
63+
apt update
64+
apt install -y libssl-dev
6165
- name: Set up Python ${{ matrix.python-version }}
6266
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
6367
with:
@@ -68,7 +72,7 @@ jobs:
6872
python -m pip install --upgrade pip wheel
6973
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
7074
- name: pip cache
71-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
75+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
7276
with:
7377
path: ${{ steps.pip-cache.outputs.dir }}
7478
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -115,7 +119,7 @@ jobs:
115119
python -m pip install --upgrade pip wheel
116120
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
117121
- name: pip cache
118-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
122+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
119123
with:
120124
path: ${{ steps.pip-cache.outputs.dir }}
121125
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -152,7 +156,7 @@ jobs:
152156
python -m pip install --upgrade pip wheel
153157
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
154158
- name: pip cache
155-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
159+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
156160
with:
157161
path: ${{ steps.pip-cache.outputs.dir }}
158162
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -188,7 +192,7 @@ jobs:
188192
python -m pip install --upgrade pip wheel
189193
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
190194
- name: pip cache
191-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
195+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
192196
with:
193197
path: ${{ steps.pip-cache.outputs.dir }}
194198
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -227,7 +231,7 @@ jobs:
227231
python -m pip install --upgrade pip wheel
228232
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
229233
- name: pip cache
230-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
234+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
231235
with:
232236
path: ${{ steps.pip-cache.outputs.dir }}
233237
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,23 @@ jobs:
5050
pip install -U -r build/collect-profile-requirements.txt
5151
- name: Install JAX
5252
run: |
53-
pip uninstall -y jax jaxlib libtpu-nightly
53+
pip uninstall -y jax jaxlib libtpu
5454
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
5555
pip install .[tpu] \
5656
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5757
5858
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
5959
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60-
pip install --pre libtpu-nightly \
60+
pip install --pre libtpu \
6161
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6262
pip install requests
6363
6464
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
6565
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
66+
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
6667
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6768
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6869
pip install requests
69-
7070
else
7171
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7272
exit 1

.github/workflows/upstream-nightly.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
strategy:
3333
fail-fast: false
3434
matrix:
35-
python-version: ["3.12"]
35+
python-version: ["3.13"]
3636
outputs:
3737
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
3838
steps:
@@ -85,7 +85,7 @@ jobs:
8585
&& steps.status.outcome == 'failure'
8686
&& github.event_name == 'schedule'
8787
&& github.repository == 'jax-ml/jax'
88-
uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1
88+
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
8989
with:
9090
name: output-${{ matrix.python-version }}-log.jsonl
9191
path: output-${{ matrix.python-version }}-log.jsonl

.github/workflows/wheel_win_x64.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [windows-2019-32core]
1919
arch: [AMD64]
20-
pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2']
20+
pyver: ['3.10', '3.11', '3.12', '3.13']
2121
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
2222
runs-on: ${{ matrix.os }}
2323

@@ -45,7 +45,7 @@ jobs:
4545
--bazel_options=--config=win_clang `
4646
--verbose
4747
48-
- uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1
48+
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
4949
with:
5050
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
5151
path: ${{ github.workspace }}\dist\*.whl

.github/workflows/windows_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
--bazel_options=--color=yes `
5454
--bazel_options=--config=win_clang
5555
56-
- uses: actions/upload-artifact@604373da6381bf24206979c74d06a550515601b9 # v4.4.1
56+
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
5757
with:
5858
name: wheels
5959
path: ${{ github.workspace }}\jax\dist\*.whl

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@ repos:
2626
files: \.py$
2727

2828
- repo: https://github.com/astral-sh/ruff-pre-commit
29-
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
29+
rev: 8983acb92ee4b01924893632cf90af926fa608f0 # frozen: v0.7.0
3030
hooks:
3131
- id: ruff
3232

3333
- repo: https://github.com/pre-commit/mirrors-mypy
34-
rev: 'd4911cfb7f1010759fde68da196036feeb25b99d' # frozen: v1.11.2
34+
rev: '102bbee94061ff02fd361ec29c27b7cb26582f5f' # frozen: v1.12.2
3535
hooks:
3636
- id: mypy
3737
files: (jax/|tests/typing_test\.py)
3838
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
39-
additional_dependencies: [types-requests==2.31.0, jaxlib==0.4.31, ml_dtypes==0.3.2, numpy==1.26.3, scipy==1.11.4]
39+
additional_dependencies: [types-requests==2.31.0, jaxlib]
4040
args: [--config=pyproject.toml]
4141

4242
- repo: https://github.com/mwouts/jupytext

CHANGELOG.md

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,32 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2020
JAX version 0.4.26. Now we removed it.
2121
See {jax-issue}`#20385` for a discussion of alternatives.
2222

23+
* Changes:
24+
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
25+
operations. The semi-public API `jax.lib.xla_client.FftType` has been
26+
deprecated.
27+
* TPU: JAX now installs TPU support from the `libtpu` package rather than
28+
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
29+
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
30+
will be removed in Q1 2025.
31+
32+
* Deprecations:
33+
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
34+
No JAX APIs consume this type, so there is no replacement.
35+
* The default behavior of {func}`jax.pure_callback` and
36+
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
37+
the `vectorized` parameter to those functions. The `vmap_method` parameter
38+
should be used instead for better defined behavior. See the discussion in
39+
{jax-issue}`#23881` for more details.
40+
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
41+
been deprecated. Use the JAX FFI instead.
42+
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
43+
`jax.lib.xla_client.ops`,
44+
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
45+
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
46+
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO
47+
instead.
48+
2349
## jax 0.4.34 (October 4, 2024)
2450

2551
* New Functionality
@@ -56,11 +82,6 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
5682
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
5783
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
5884
`jax.errors.JaxRuntimeError` instead.
59-
* The default behavior of {func}`jax.pure_callback` and
60-
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
61-
the `vectorized` parameter to those functions. The `vmap_method` parameter
62-
should be used instead for better defined behavior. See the discussion in
63-
{jax-issue}`#23881` for more details.
6485

6586
* Deletion:
6687
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -387,24 +387,24 @@ Some standouts:
387387

388388
### Supported platforms
389389

390-
| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac ARM | Windows x86_64 | Windows WSL2 x86_64 |
390+
| | Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 |
391391
|------------|--------------|---------------|--------------|--------------|----------------|---------------------|
392392
| CPU | yes | yes | yes | yes | yes | yes |
393393
| NVIDIA GPU | yes | yes | no | n/a | no | experimental |
394394
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
395-
| AMD GPU | experimental | no | no | n/a | no | no |
396-
| Apple GPU | n/a | no | experimental | experimental | n/a | n/a |
395+
| AMD GPU | yes | no | experimental | n/a | no | no |
396+
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
397397

398398

399399
### Instructions
400400

401-
| Hardware | Instructions |
402-
|------------|-----------------------------------------------------------------------------------------------------------------|
403-
| CPU | `pip install -U jax` |
404-
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
405-
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
406-
| AMD GPU | Use [Docker](https://hub.docker.com/r/rocm/jax) or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
407-
| Apple GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
401+
| Platform | Instructions |
402+
|-----------------|-----------------------------------------------------------------------------------------------------------------|
403+
| CPU | `pip install -U jax` |
404+
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
405+
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
406+
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
407+
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
408408

409409
See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
410410
for information on alternative installation strategies. These include compiling

build/build.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ def write_bazelrc(*, remote_build,
285285
if enable_cuda:
286286
f.write("build --config=cuda\n")
287287
if use_cuda_nvcc:
288-
f.write("build --config=cuda_nvcc\n")
288+
f.write("build --config=build_cuda_with_nvcc\n")
289289
else:
290-
f.write("build --config=cuda_clang\n")
290+
f.write("build --config=build_cuda_with_clang\n")
291291
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
292292
if not enable_nccl:
293293
f.write("build --config=nonccl\n")
@@ -301,9 +301,12 @@ def write_bazelrc(*, remote_build,
301301
f.write(
302302
f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n')
303303
if enable_rocm:
304-
f.write("build --config=rocm\n")
304+
f.write("build --config=rocm_base\n")
305305
if not enable_nccl:
306306
f.write("build --config=nonccl\n")
307+
if use_clang:
308+
f.write("build --config=rocm\n")
309+
f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n")
307310
if python_version:
308311
f.write(
309312
"build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
@@ -495,7 +498,7 @@ def main():
495498
help="A comma-separated list of CUDA compute capabilities to support.")
496499
parser.add_argument(
497500
"--rocm_amdgpu_targets",
498-
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
501+
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
499502
help="A comma-separated list of ROCm amdgpu targets to support.")
500503
parser.add_argument(
501504
"--rocm_path",

0 commit comments

Comments
 (0)