Skip to content

Commit 6fe9a71

Browse files
committed
Merge branch 'main' of https://github.com/jax-ml/jax
2 parents 53f20d9 + 9584ee3 commit 6fe9a71

File tree

215 files changed

+7489
-3169
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

215 files changed

+7489
-3169
lines changed

.github/workflows/asan.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212
branches:
1313
- main
1414
paths:
15-
- '**/workflows/asan.yml'
15+
- '**/workflows/asan.yaml'
1616

1717
jobs:
1818
asan:
Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Run Bazel CPU tests (RBE)
1+
name: CI - Bazel CPU tests (RBE)
22

33
on:
44
# pull_request:
@@ -15,32 +15,31 @@ on:
1515
- 'yes'
1616
- 'no'
1717

18+
concurrency:
19+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
20+
cancel-in-progress: true
21+
1822
jobs:
19-
run_bazel_rbe_cpu_tests:
20-
continue-on-error: true
21-
defaults:
22-
run:
23-
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd.
24-
shell: bash
23+
run_tests:
24+
if: github.event.repository.fork == false
2525
strategy:
2626
matrix:
2727
runner: ["windows-x86-n2-64", "linux-x86-n2-16", "linux-arm64-t2a-16"]
2828

2929
runs-on: ${{ matrix.runner }}
30-
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12') ||
31-
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') ||
30+
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
31+
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
3232
(contains(matrix.runner, 'windows-x86') && null) }}
3333

34+
3435
env:
35-
JAXCI_CLONE_MAIN_XLA: 1
3636
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
3737

3838
steps:
3939
- uses: actions/checkout@v3
40-
# Halt for testing
4140
- name: Wait For Connection
4241
uses: google-ml-infra/actions/ci_connection@main
4342
with:
4443
halt-dispatch-input: ${{ inputs.halt-for-connection }}
45-
- name: Run Bazel CPU Tests
44+
- name: Run Bazel CPU Tests with RBE
4645
run: ./ci/run_bazel_test_cpu_rbe.sh

.github/workflows/bazel_gpu_rbe.yml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Run Bazel GPU tests (RBE)
1+
name: CI - Bazel GPU tests (RBE)
22

33
on:
44
# pull_request:
@@ -15,26 +15,28 @@ on:
1515
- 'yes'
1616
- 'no'
1717

18+
concurrency:
19+
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
20+
cancel-in-progress: true
21+
1822
jobs:
19-
build:
23+
run_tests:
24+
if: github.event.repository.fork == false
2025
strategy:
2126
matrix:
2227
runner: ["linux-x86-n2-16"]
2328

2429
runs-on: ${{ matrix.runner }}
25-
container:
26-
image: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
30+
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'
2731

2832
env:
29-
JAXCI_CLONE_MAIN_XLA: 1
30-
JAXCI_HERMETIC_PYTHON_VERSION: 3.12
33+
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
3134

3235
steps:
3336
- uses: actions/checkout@v3
34-
# Halt for testing
3537
- name: Wait For Connection
3638
uses: google-ml-infra/actions/ci_connection@main
3739
with:
3840
halt-dispatch-input: ${{ inputs.halt-for-connection }}
39-
- name: Run Bazel GPU tests using RBE
41+
- name: Run Bazel GPU Tests with RBE
4042
run: ./ci/run_bazel_test_gpu_rbe.sh

.github/workflows/ci-build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
with:
3636
python-version: 3.11
3737
- run: python -m pip install pre-commit
38-
- uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
38+
- uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
3939
with:
4040
path: ~/.cache/pre-commit
4141
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}

.github/workflows/jax-array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'bcd5919bbbdf4d4806b5b2613b4d8c0bc0625c54' # Latest commit as of 2024-10-31 👻
31+
ref: 'a3f3f376308e64f0ac15b307dfe27be945409e41' # Latest commit as of 2024-11-14
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}

CHANGELOG.md

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4040
`platforms` instead.
4141
* Hashing of tracers, which has been deprecated since version 0.4.30, now
4242
results in a `TypeError`.
43+
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
44+
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
45+
on the function inputs.
46+
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
47+
return NaN for negative integer inputs, to match the behavior of SciPy from
48+
https://github.com/scipy/scipy/pull/21827.
49+
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
4350

4451
* New Features
4552
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
@@ -48,6 +55,21 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4855
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
4956
declared inline via {func}`dataclasses.field`. See the function documentation
5057
for examples.
58+
* Added {func}`jax.numpy.put_along_axis`.
59+
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
60+
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
61+
supported on GPU. See {jax-issue}`#24663` for more details.
62+
63+
* Bug fixes
64+
* Fixed a bug where the GPU implementations of LU and QR decomposition would
65+
result in an indexing overflow for batch sizes close to int32 max. See
66+
{jax-issue}`#24843` for more details.
67+
68+
* Deprecations
69+
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
70+
use `jax.Array` instead.
71+
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
72+
instead.
5173

5274
## jax 0.4.35 (Oct 22, 2024)
5375

@@ -79,7 +101,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
79101
* The semi-public API `jax.lib.xla_client.register_custom_call_target` has
80102
been deprecated. Use the JAX FFI instead.
81103
* The semi-public APIs `jax.lib.xla_client.dtype_to_etype`,
82-
`jax.lib.xla_client.ops`,
104+
`jax.lib.xla_client.ops`,
83105
`jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`,
84106
`jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and
85107
`jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO

README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ You can mix `jit` and `grad` and any other JAX transformation however you like.
193193

194194
Using `jit` puts constraints on the kind of Python control flow
195195
the function can use; see
196-
the [Gotchas
197-
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
196+
the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
198197
for more.
199198

200199
### Auto-vectorization with `vmap`
@@ -353,7 +352,7 @@ Some standouts:
353352
1. [In-place mutating updates of
354353
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
355354
1. [Random numbers are
356-
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
355+
different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
357356
1. If you're looking for [convolution
358357
operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
359358
they're in the `jax.lax` package.
@@ -373,7 +372,7 @@ Some standouts:
373372
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
374373
np.float32)).dtype` is `float64` rather than `float32`.
375374
1. Some transformations, like `jit`, [constrain how you can use Python control
376-
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
375+
flow](https://jax.readthedocs.io/en/latest/control-flow.html).
377376
You'll always get loud errors if something goes wrong. You might have to use
378377
[`jit`'s `static_argnums`
379378
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
@@ -394,6 +393,7 @@ Some standouts:
394393
| Google TPU | yes | n/a | n/a | n/a | n/a | n/a |
395394
| AMD GPU | yes | no | experimental | n/a | no | no |
396395
| Apple GPU | n/a | no | n/a | experimental | n/a | n/a |
396+
| Intel GPU | experimental | n/a | n/a | n/a | no | no |
397397

398398

399399
### Instructions
@@ -405,6 +405,7 @@ Some standouts:
405405
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
406406
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
407407
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
408+
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |
408409

409410
See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
410411
for information on alternative installation strategies. These include compiling

ci/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# JAX continuous integration..
1+
# JAX continuous integration
22

3-
> **Warning** This folder is still under construction. It is part of an ongoing
3+
> [!WARNING]
4+
> This folder is still under construction. It is part of an ongoing
45
> effort to improve the structure of CI and build related files within the
56
> JAX repo. This warning will be removed when the contents of this
67
> directory are stable and appropriate documentation around its usage is in

ci/envs/default.env

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
# This file contains all the default values for the "JAXCI_" environment
16-
# variables used in the CI scripts.
16+
# variables used in the CI scripts. These variables are used to control the
17+
# behavior of the CI scripts such as the Python version used, path to JAX/XLA
18+
# repo, if to clone XLA repo, etc.
1719

1820
# The path to the JAX git repository.
1921
export JAXCI_JAX_GIT_DIR=$(pwd)
@@ -22,8 +24,9 @@ export JAXCI_JAX_GIT_DIR=$(pwd)
2224
# set.
2325
export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')}
2426

25-
# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository if you want to
26-
# use a local copy of XLA instead of the pinned version in the WORKSPACE.
27+
# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local
28+
# copy of XLA instead of the pinned version in the WORKSPACE. When
29+
# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically.
2730
export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-}
2831

2932
# If set to 1, the builds will clone the XLA repository at HEAD and set its

ci/run_bazel_test_cpu_rbe.sh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# -o allexport: export all functions and variables to be available to subscripts
2323
set -exu -o history -o allexport
2424

25-
# Inherit default JAXCI environment variables.
25+
# Source default JAXCI environment variables.
2626
source ci/envs/default.env
2727

2828
# Clone XLA at HEAD if path to local XLA is not provided
@@ -33,6 +33,7 @@ fi
3333
# Set up the build environment.
3434
source "ci/utilities/setup_build_environment.sh"
3535

36+
# Run Bazel CPU tests with RBE.
3637
os=$(uname -s | awk '{print tolower($0)}')
3738
arch=$(uname -m)
3839

@@ -43,8 +44,8 @@ arch=$(uname -m)
4344
# single machine can take a long time, we skip running them on these
4445
# platforms.
4546
if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then
46-
echo "Building RBE CPU tests..."
47-
bazel build --config=rbe_cross_compile_${os}_${arch} \
47+
echo "Building RBE CPU tests..."
48+
bazel build --config=rbe_cross_compile_${os}_${arch} \
4849
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
4950
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
5051
--test_env=JAX_NUM_GENERATED_CASES=25 \
@@ -54,8 +55,8 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] )
5455
--color=yes \
5556
//tests:cpu_tests //tests:backend_independent_tests
5657
else
57-
echo "Running RBE CPU tests..."
58-
bazel test --config=rbe_${os}_${arch} \
58+
echo "Running RBE CPU tests..."
59+
bazel test --config=rbe_${os}_${arch} \
5960
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
6061
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
6162
--test_env=JAX_NUM_GENERATED_CASES=25 \

0 commit comments

Comments
 (0)