Skip to content

Commit 97c8d5d

Browse files
committed
Merge branch 'main' of https://github.com/google/jax
2 parents f8b2c2b + 816947b commit 97c8d5d

File tree

140 files changed

+6440
-3314
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

140 files changed

+6440
-3314
lines changed

.bazelrc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
############################################################################
22
# All default build options below.
33

4-
# Required by OpenXLA
5-
# https://github.com/openxla/xla/issues/1323
6-
build --nocheck_visibility
7-
84
# Sets the default Apple platform to macOS.
95
build --apple_platform_type=macos
106
build --macos_minimum_os=10.14
@@ -57,7 +53,6 @@ build:native_arch_posix --host_copt=-march=native
5753

5854
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
5955

60-
build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang"
6156
# Disable clang extention that rejects type definitions within offsetof.
6257
# This was added in clang-16 by https://reviews.llvm.org/D133574.
6358
# Can be removed once upb is updated, since a type definition is used within
@@ -91,12 +86,14 @@ build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
9186
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
9287
# The list of CUDA pip packages that JAX depends on are present in setup.py.
9388
build:cuda --linkopt=-Wl,--disable-new-dtags
94-
build:cuda --@local_config_cuda//:cuda_compiler=clang
95-
build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
9689

9790
# This flag is needed to include CUDA libraries for bazel tests.
9891
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
9992

93+
build:cuda_clang --config=clang
94+
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
95+
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
96+
10097
# Build with NVCC for CUDA
10198
build:cuda_nvcc --config=cuda
10299
build:cuda_nvcc --config=clang
@@ -202,7 +199,7 @@ build:rbe_linux --host_linkopt=-lm
202199
# Use the GPU toolchain until the CPU one is ready.
203200
# https://github.com/bazelbuild/bazel/issues/13623
204201
build:rbe_cpu_linux_base --config=rbe_linux
205-
build:rbe_cpu_linux_base --config=clang
202+
build:rbe_cpu_linux_base --config=cuda_clang
206203
build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
207204
build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain"
208205
build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"

.github/workflows/ci-build.yaml

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ jobs:
2929
runs-on: ubuntu-latest
3030
timeout-minutes: 5
3131
steps:
32-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
32+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3333
- name: Set up Python 3.11
34-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
34+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3535
with:
3636
python-version: 3.11
37-
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1
37+
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
3838

3939
build:
4040
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
@@ -57,9 +57,9 @@ jobs:
5757
prng-upgrade: 0
5858
num_generated_cases: 1
5959
steps:
60-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
60+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
6161
- name: Set up Python ${{ matrix.python-version }}
62-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
62+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
6363
with:
6464
python-version: ${{ matrix.python-version }}
6565
- name: Get pip cache dir
@@ -68,7 +68,7 @@ jobs:
6868
python -m pip install --upgrade pip wheel
6969
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
7070
- name: pip cache
71-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
71+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
7272
with:
7373
path: ${{ steps.pip-cache.outputs.dir }}
7474
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -104,9 +104,9 @@ jobs:
104104
matrix:
105105
python-version: ['3.10']
106106
steps:
107-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
107+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
108108
- name: Set up Python ${{ matrix.python-version }}
109-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
109+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
110110
with:
111111
python-version: ${{ matrix.python-version }}
112112
- name: Get pip cache dir
@@ -115,7 +115,7 @@ jobs:
115115
python -m pip install --upgrade pip wheel
116116
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
117117
- name: pip cache
118-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
118+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
119119
with:
120120
path: ${{ steps.pip-cache.outputs.dir }}
121121
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -141,9 +141,9 @@ jobs:
141141
matrix:
142142
python-version: ['3.10']
143143
steps:
144-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
144+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
145145
- name: Set up Python ${{ matrix.python-version }}
146-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
146+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
147147
with:
148148
python-version: ${{ matrix.python-version }}
149149
- name: Get pip cache dir
@@ -152,7 +152,7 @@ jobs:
152152
python -m pip install --upgrade pip wheel
153153
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
154154
- name: pip cache
155-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
155+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
156156
with:
157157
path: ${{ steps.pip-cache.outputs.dir }}
158158
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -177,9 +177,9 @@ jobs:
177177
enable-x64: 0
178178
num_generated_cases: 10
179179
steps:
180-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
180+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
181181
- name: Set up Python ${{ matrix.python-version }}
182-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
182+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
183183
with:
184184
python-version: ${{ matrix.python-version }}
185185
- name: Get pip cache dir
@@ -188,7 +188,7 @@ jobs:
188188
python -m pip install --upgrade pip wheel
189189
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
190190
- name: pip cache
191-
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
191+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
192192
with:
193193
path: ${{ steps.pip-cache.outputs.dir }}
194194
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -210,4 +210,37 @@ jobs:
210210
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
211211
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
212212
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
213-
213+
214+
ffi:
215+
name: FFI example
216+
runs-on: ubuntu-latest
217+
timeout-minutes: 5
218+
steps:
219+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
220+
- name: Set up Python 3.11
221+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
222+
with:
223+
python-version: 3.11
224+
- name: Get pip cache dir
225+
id: pip-cache
226+
run: |
227+
python -m pip install --upgrade pip wheel
228+
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
229+
- name: pip cache
230+
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2
231+
with:
232+
path: ${{ steps.pip-cache.outputs.dir }}
233+
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
234+
- name: Install JAX
235+
run: pip install .
236+
- name: Build and install example project
237+
run: python -m pip install -v ./examples/ffi[test]
238+
env:
239+
# We test building using GCC instead of clang. All other JAX builds use
240+
# clang, but it is useful to make sure that FFI users can compile using
241+
# a different toolchain. GCC is the default compiler on the
242+
# 'ubuntu-latest' runner, but we still set this explicitly just to be
243+
# clear.
244+
CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
245+
- name: Run tests
246+
run: python -m pytest examples/ffi/tests

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
# https://opensource.google/documentation/reference/github/services#actions
4444
# mandates using a specific commit for non-Google actions. We use
4545
# https://github.com/sethvargo/ratchet to pin specific versions.
46-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
46+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
4747
- name: Install JAX test requirements
4848
run: |
4949
pip install -U -r build/test-requirements.txt

.github/workflows/jax-array-api.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ jobs:
2222

2323
steps:
2424
- name: Checkout jax
25-
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4
25+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2626
- name: Checkout array-api-tests
27-
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet: actions/checkout@v4
27+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
3131
ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}
35-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
35+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3636
with:
3737
python-version: ${{ matrix.python-version }}
3838
- name: Install dependencies

.github/workflows/metal_plugin_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727

2828
steps:
2929
- name: Get repo
30-
uses: actions/checkout@v4
30+
uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3131
with:
3232
path: jax
3333
- name: Setup build and test enviroment

.github/workflows/upstream-nightly.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ jobs:
3636
outputs:
3737
artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }}
3838
steps:
39-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
39+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
4040
- name: Set up Python ${{ matrix.python-version }}
41-
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
41+
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
4242
with:
4343
python-version: ${{ matrix.python-version }}
4444
- name: Install JAX test requirements
@@ -85,7 +85,7 @@ jobs:
8585
&& steps.status.outcome == 'failure'
8686
&& github.event_name == 'schedule'
8787
&& github.repository == 'jax-ml/jax'
88-
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
88+
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
8989
with:
9090
name: output-${{ matrix.python-version }}-log.jsonl
9191
path: output-${{ matrix.python-version }}-log.jsonl
@@ -106,11 +106,11 @@ jobs:
106106
run:
107107
shell: bash
108108
steps:
109-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
110-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
109+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
110+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
111111
with:
112112
python-version: "3.x"
113-
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # ratchet:actions/download-artifact@v4
113+
- uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
114114
with:
115115
path: /tmp/workspace/logs
116116
- name: install requirements
@@ -123,7 +123,7 @@ jobs:
123123
cat logs/*.jsonl > pytest-logs.txt
124124
python .github/workflows/parse_logs.py pytest-logs.txt --outfile=parsed-logs.txt
125125
- name: Report failures
126-
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # ratchet:actions/github-script@v7
126+
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
127127
with:
128128
github-token: ${{ secrets.GITHUB_TOKEN }}
129129
script: |

.github/workflows/wheel_win_x64.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ jobs:
2525
- name: Install LLVM/Clang
2626
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
2727

28-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
28+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
2929

30-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
30+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3131
with:
3232
python-version: ${{ matrix.pyver }}
3333
cache: 'pip'
@@ -45,7 +45,7 @@ jobs:
4545
--bazel_options=--config=win_clang `
4646
--verbose
4747
48-
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
48+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
4949
with:
5050
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
5151
path: ${{ github.workspace }}\dist\*.whl

.github/workflows/windows_ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ jobs:
3131
- name: Install LLVM/Clang
3232
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade
3333

34-
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
34+
- uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0
3535
with:
3636
path: jax
3737

38-
- uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
38+
- uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
3939
with:
4040
python-version: ${{ matrix.pyver }}
4141
cache: 'pip'
@@ -53,7 +53,7 @@ jobs:
5353
--bazel_options=--color=yes `
5454
--bazel_options=--config=win_clang
5555
56-
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
56+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
5757
with:
5858
name: wheels
5959
path: ${{ github.workspace }}\jax\dist\*.whl
@@ -66,7 +66,7 @@ jobs:
6666
PY_COLORS: 1
6767
run: |
6868
cd jax
69+
python -m pip install --pre --find-links ${{ github.workspace }}\jax\dist jaxlib
6970
python -m pip install -e ${{ github.workspace }}\jax
70-
python -m pip install --no-index --find-links ${{ github.workspace }}\jax\dist jaxlib
7171
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
7272
pytest -n auto --tb=short tests examples

0 commit comments

Comments
 (0)