Skip to content

Commit f8b2c2b

Browse files
committed
Merge branch 'main' of https://github.com/google/jax
2 parents fd7a21e + aa73aa0 commit f8b2c2b

File tree

395 files changed

+10860
-8611
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

395 files changed

+10860
-8611
lines changed

.bazelrc

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@ build:native_arch_posix --host_copt=-march=native
5757

5858
build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
5959

60+
build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang"
61+
# Disable clang extention that rejects type definitions within offsetof.
62+
# This was added in clang-16 by https://reviews.llvm.org/D133574.
63+
# Can be removed once upb is updated, since a type definition is used within
64+
# offset of in the current version of ubp.
65+
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
66+
build:clang --copt=-Wno-gnu-offsetof-extensions
67+
# Disable clang extention that rejects unknown arguments.
68+
build:clang --copt=-Qunused-arguments
69+
6070
build:cuda --repo_env TF_NEED_CUDA=1
6171
build:cuda --repo_env TF_NCCL_USE_STUB=1
6272
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
@@ -68,14 +78,6 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
6878
# Default hermetic CUDA and CUDNN versions.
6979
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
7080
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
71-
# This flag is needed to include CUDA libraries for bazel tests.
72-
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
73-
74-
# Requires MSVC and LLVM to be installed
75-
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
76-
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
77-
build:win_clang --compiler=clang-cl
78-
7981
# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
8082
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
8183
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
@@ -89,23 +91,18 @@ build:win_clang --compiler=clang-cl
8991
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
9092
# The list of CUDA pip packages that JAX depends on are present in setup.py.
9193
build:cuda --linkopt=-Wl,--disable-new-dtags
94+
build:cuda --@local_config_cuda//:cuda_compiler=clang
95+
build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
9296

93-
build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
94-
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
95-
# Disable clang extention that rejects type definitions within offsetof.
96-
# This was added in clang-16 by https://reviews.llvm.org/D133574.
97-
# Can be removed once upb is updated, since a type definition is used within
98-
# offset of in the current version of ubp.
99-
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
100-
build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
101-
# Disable clang extention that rejects unknown arguments.
102-
build:cuda_clang --copt=-Qunused-arguments
97+
# This flag is needed to include CUDA libraries for bazel tests.
98+
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true
10399

104-
# Build with nvcc for CUDA and clang for host
105-
build:nvcc_clang --config=cuda
106-
build:nvcc_clang --config=cuda_clang
107-
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
108-
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
100+
# Build with NVCC for CUDA
101+
build:cuda_nvcc --config=cuda
102+
build:cuda_nvcc --config=clang
103+
build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc
104+
build:cuda_nvcc --action_env=TF_NVCC_CLANG="1"
105+
build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
109106

110107
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
111108
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
@@ -114,6 +111,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1
114111

115112
build:nonccl --define=no_nccl_support=true
116113

114+
# Requires MSVC and LLVM to be installed
115+
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
116+
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
117+
build:win_clang --compiler=clang-cl
118+
117119
# Windows has a relatively short command line limit, which JAX has begun to hit.
118120
# See https://docs.bazel.build/versions/main/windows.html
119121
build:windows --features=compiler_param_file
@@ -200,7 +202,7 @@ build:rbe_linux --host_linkopt=-lm
200202
# Use the GPU toolchain until the CPU one is ready.
201203
# https://github.com/bazelbuild/bazel/issues/13623
202204
build:rbe_cpu_linux_base --config=rbe_linux
203-
build:rbe_cpu_linux_base --config=cuda_clang
205+
build:rbe_cpu_linux_base --config=clang
204206
build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
205207
build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain"
206208
build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"
@@ -215,13 +217,15 @@ build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base
215217
build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
216218
build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base
217219
build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
220+
build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base
221+
build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"
218222

219223
build:rbe_linux_cuda_base --config=rbe_linux
220224
build:rbe_linux_cuda_base --config=cuda
221225
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
222226

223227
build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
224-
build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang
228+
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc
225229
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
226230
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
227231
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
@@ -237,6 +241,8 @@ build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base
237241
build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11"
238242
build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base
239243
build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12"
244+
build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base
245+
build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13"
240246

241247
# These you may need to change for your own GCP project.
242248
build:tensorflow_testing_rbe --project_id=tensorflow-testing

.github/ISSUE_TEMPLATE/bug-report.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ body:
2020
* If you prefer a non-templated issue report, click [here][Raw report].
2121
2222
23-
[Discussions]: https://github.com/google/jax/discussions
23+
[Discussions]: https://github.com/jax-ml/jax/discussions
2424
25-
[issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues
25+
[issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues
2626
27-
[Raw report]: http://github.com/google/jax/issues/new
27+
[Raw report]: http://github.com/jax-ml/jax/issues/new
2828
- type: textarea
2929
attributes:
3030
label: Description

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
blank_issues_enabled: false
22
contact_links:
33
- name: Have questions or need support?
4-
url: https://github.com/google/jax/discussions
4+
url: https://github.com/jax-ml/jax/discussions
55
about: Please ask questions on the Discussions tab

.github/workflows/jax-array-api.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
name: JAX Array API
22

33
on:
4-
workflow_dispatch: # allows triggering the workflow run manually
5-
pull_request: # Automatically trigger on pull requests affecting particular files
4+
push:
5+
branches:
6+
- main
7+
pull_request:
68
branches:
79
- main
8-
paths:
9-
- '**workflows/jax-array-api.yml'
10-
- '**experimental/array_api/**'
1110

1211
concurrency:
1312
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}

.github/workflows/self_hosted_runner_utils/setup_runner.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ runner_token="$3"
3131
# - sets empty string as default to avoid unbound variable error from set -u
3232
jax_repo_url="${4-}"
3333
if [ -z "${jax_repo_url}" ]; then
34-
jax_repo_url="https://github.com/google/jax"
34+
jax_repo_url="https://github.com/jax-ml/jax"
3535
fi
3636

3737
# Create `runner` user. This user won't have sudo access unless you ssh into the
@@ -67,7 +67,7 @@ cd ~/
6767
6868
git clone ${jax_repo_url}
6969
70-
# Based on https://github.com/google/jax/settings/actions/runners/new
70+
# Based on https://github.com/jax-ml/jax/settings/actions/runners/new
7171
# (will be 404 for github users with insufficient repo permissions)
7272
mkdir actions-runner && cd actions-runner
7373
curl -o actions-runner-linux-x64.tar.gz -L ${actions_runner_download}

.github/workflows/upstream-nightly.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ jobs:
8484
failure()
8585
&& steps.status.outcome == 'failure'
8686
&& github.event_name == 'schedule'
87-
&& github.repository == 'google/jax'
88-
uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
87+
&& github.repository == 'jax-ml/jax'
88+
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
8989
with:
9090
name: output-${{ matrix.python-version }}-log.jsonl
9191
path: output-${{ matrix.python-version }}-log.jsonl

.github/workflows/wheel_win_x64.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [windows-2019-32core]
1919
arch: [AMD64]
20-
pyver: ['3.10', '3.11', '3.12']
20+
pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2']
2121
name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build
2222
runs-on: ${{ matrix.os }}
2323

@@ -45,7 +45,7 @@ jobs:
4545
--bazel_options=--config=win_clang `
4646
--verbose
4747
48-
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
48+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
4949
with:
5050
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
5151
path: ${{ github.workspace }}\dist\*.whl

.github/workflows/windows_ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
--bazel_options=--color=yes `
5454
--bazel_options=--config=win_clang
5555
56-
- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
56+
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4
5757
with:
5858
name: wheels
5959
path: ${{ github.workspace }}\jax\dist\*.whl

0 commit comments

Comments
 (0)