Skip to content

Commit cf4a68e

Browse files
committed
Merge branch 'main' of https://github.com/google-ml-infra/jax-fork into srnitin/task-jax-ci-rework
2 parents 3aedc18 + f90bbfa commit cf4a68e

File tree

257 files changed

+10616
-3803
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

257 files changed

+10616
-3803
lines changed

ci/.bazelrc renamed to .bazelrc

Lines changed: 40 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
# Make Bazel print out all options from rc files.
55
build --announce_rc
66

7-
# Required by OpenXLA
8-
# https://github.com/openxla/xla/issues/1323
9-
build --nocheck_visibility
10-
117
# By default, execute all actions locally.
128
build --spawn_strategy=local
139

@@ -17,9 +13,6 @@ build --enable_platform_specific_config
1713

1814
build --experimental_cc_shared_library
1915

20-
# Disable enabled-by-default TensorFlow features that we don't care about.
21-
build --define=no_gcp_support=true
22-
2316
# Do not use C-Ares when building gRPC.
2417
build --define=grpc_no_ares=true
2518

@@ -33,12 +26,9 @@ build --output_filter=DONT_MATCH_ANYTHING
3326

3427
build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
3528

36-
build --verbose_failures=true
37-
3829
# #############################################################################
3930
# Platform Specific configs below. These are automatically picked up by Bazel
40-
# depending on the platform that is running the build. If you would like to
41-
# disable this behavior, pass in `--noenable_platform_specific_config`
31+
# depending on the platform that is running the build.
4232
# #############################################################################
4333
build:linux --config=posix
4434
build:linux --copt=-Wno-unknown-warning-option
@@ -56,7 +46,7 @@ build:macos --apple_platform_type=macos
5646
build:windows --features=compiler_param_file
5747
build:windows --features=archive_param_file
5848

59-
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
49+
# XLA uses M_* math constants that only get defined by MSVC headers if
6050
# _USE_MATH_DEFINES is defined.
6151
build:windows --copt=/D_USE_MATH_DEFINES
6252
build:windows --host_copt=/D_USE_MATH_DEFINES
@@ -81,10 +71,10 @@ build:windows --host_linkopt=/OPT:ICF
8171
build:windows --incompatible_strict_action_env=true
8272

8373
# #############################################################################
84-
# Feature-specific configurations. These are used by the Local and CI configs
85-
# below depending on the type of build. E.g. `local_linux_x86_64` inherits the
86-
# Linux x86 configs such as `avx_linux` and `mkl_open_source_only`,
87-
# `local_cuda_base` inherits `cuda` and `build_cuda_with_nvcc`, etc.
74+
# Feature-specific configurations. These are used by the CI configs below
75+
# depending on the type of build. E.g. `ci_linux_x86_64` inherits the Linux x86
76+
# configs such as `avx_linux` and `mkl_open_source_only`, `ci_linux_x86_64_cuda`
77+
# inherits `cuda` and `build_cuda_with_nvcc`, etc.
8878
# #############################################################################
8979
build:nonccl --define=no_nccl_support=true
9080

@@ -158,83 +148,41 @@ build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-c
158148
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
159149
build:win_clang --compiler=clang-cl
160150

161-
# Configs for building ROCM
162-
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
163-
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
164-
build:rocm --repo_env TF_NEED_ROCM=1
165-
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
151+
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
152+
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
153+
build:rocm_base --repo_env TF_NEED_ROCM=1
154+
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"
155+
156+
# Build with hipcc for ROCm and clang for the host.
157+
build:rocm --config=rocm_base
158+
build:rocm --action_env=TF_ROCM_CLANG="1"
159+
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
160+
build:rocm --copt=-Wno-gnu-offsetof-extensions
161+
build:rocm --copt=-Qunused-arguments
162+
build:rocm --action_env=TF_HIPCC_CLANG="1"
166163

167164
# #############################################################################
168165
# Cache options below.
169166
# #############################################################################
170-
# Public read-only cache for macOS builds. The "oct2023" in the URL is just the
171-
# date when the bucket was created and can be disregarded. It still contains the
172-
# latest cache that is being used.
167+
# Public read-only cache for Mac builds. JAX uses a GCS bucket to store cache
168+
# from JAX's Mac CI build. By applying --config=macos_cache, any local Mac build
169+
# should be able to read from this cache and potentially see a speedup. The
170+
# "oct2023" in the URL is just the date when the bucket was created and can be
171+
# disregarded. It still contains the latest cache that is being used.
173172
build:macos_cache --remote_cache="https://storage.googleapis.com/tensorflow-macos-bazel-cache/oct2023" --remote_upload_local_results=false
174-
# Cache pushes are limited to Jax's CI system.
175-
build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials
176-
177-
# #############################################################################
178-
# Local Build config options below. Use these configs to build JAX locally.
179-
# #############################################################################
180-
# Set base CUDA configs. These are inherited by the Linux x86 and Linux Aarch64
181-
# CUDA configs.
182-
build:local_cuda_base --config=cuda
183-
184-
# JAX uses NVCC to build CUDA targets. If you would like to build CUDA targets
185-
# with Clang, change this to `--config=build_cuda_with_clang`
186-
build:local_cuda_base --config=build_cuda_with_nvcc
187-
188-
# Linux x86 Local configs
189-
build:local_linux_x86_64 --config=avx_linux
190-
build:local_linux_x86_64 --config=avx_posix
191-
build:local_linux_x86_64 --config=mkl_open_source_only
192-
193-
build:local_linux_x86_64_cuda --config=local_linux_x86_64
194-
build:local_linux_x86_64_cuda --config=local_cuda_base
195-
196-
# Linux Aarch64 Local configs
197-
# No custom config for Linux Aarch64. If building for CPU, run
198-
# `bazel build|test //path/to:target`. If building for CUDA, run
199-
# `bazel build|test --config=local_cuda_base //path/to:target`.
200-
build:local_linux_aarch64_cuda --config=local_cuda_base
201-
202-
# Mac x86 Local configs
203-
# For Mac x86, we target compatibility with macOS 10.14.
204-
build:local_darwin_x86_64 --macos_minimum_os=10.14
205-
# Read-only cache to boost build times.
206-
build:local_darwin_x86_64 --config=macos_cache
207-
208-
# Mac Arm64 Local configs
209-
# For Mac Arm64, we target compatibility with macOS 12.
210-
build:local_darwin_arm64 --macos_minimum_os=12.0
211-
# Read-only cache to boost build times.
212-
build:local_darwin_arm64 --config=macos_cache_push
213173

214-
# Windows x86 Local configs
215-
build:local_windows_amd64 --config=avx_windows
174+
# Cache pushes are limited to JAX's CI system.
175+
build:macos_cache_push --config=macos_cache --remote_upload_local_results=true --google_default_credentials
216176

217177
# #############################################################################
218178
# CI Build config options below.
219179
# JAX uses these configs in CI builds for building artifacts and when running
220180
# Bazel tests.
221-
#
222-
# These configs are pretty much the same as the local build configs above. The
223-
# difference is that, in CI, we build with Clang and pass in a custom
224-
# non-hermetic toolchain to ensure manylinux compliance for Linux builds and
225-
# for using RBE on Windows. Because the toolchain is non-hermetic, it requires
226-
# specific versions of the compiler and other tools to be present on the system
227-
# in specific locations, which is why the Linux and Windows builds are run in a
228-
# Docker container.
229181
# #############################################################################
230-
231182
# Linux x86 CI configs
232-
# Inherit the local Linux x86 configs.
233-
build:ci_linux_x86_64 --config=local_linux_x86_64
234-
235-
# CI builds use Clang as the default compiler so we inherit Clang
236-
# specific configs
237-
build:ci_linux_x86_64 --config=clang
183+
build:ci_linux_x86_64 --config=avx_linux --config=avx_posix
184+
build:ci_linux_x86_64 --config=mkl_open_source_only
185+
build:ci_linux_x86_64 --config=clang --verbose_failures=true
238186

239187
# TODO(b/356695103): We do not have a CPU only toolchain so we use the CUDA
240188
# toolchain for both CPU and GPU builds.
@@ -249,45 +197,42 @@ build:ci_linux_x86_64 --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bi
249197
# The toolchain in `--config=cuda` needs to be read before the toolchain in
250198
# `--config=ci_linux_x86_64`. Otherwise, we run into issues with manylinux
251199
# compliance.
252-
build:ci_linux_x86_64_cuda --config=local_cuda_base
200+
build:ci_linux_x86_64_cuda --config=cuda --config=build_cuda_with_nvcc
253201
build:ci_linux_x86_64_cuda --config=ci_linux_x86_64
254202

255203
# Linux Aarch64 CI configs
256-
build:ci_linux_aarch64_base --config=clang
204+
build:ci_linux_aarch64_base --config=clang --verbose_failures=true
257205
build:ci_linux_aarch64_base --action_env=TF_SYSROOT="/dt10"
258206

259207
build:ci_linux_aarch64 --config=ci_linux_aarch64_base
260208
build:ci_linux_aarch64 --host_crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
261209
build:ci_linux_aarch64 --crosstool_top="@ml2014_clang_aarch64_config_aarch64//crosstool:toolchain"
262210

263-
# CUDA configs for Linux Aarch64 do not pass in the crosstool top flag from
211+
# CUDA configs for Linux Aarch64 do not pass in the crosstool_top flag from
264212
# above because the Aarch64 toolchain rule does not support building with NVCC.
265213
# Instead, we use `@local_config_cuda//crosstool:toolchain` from --config=cuda
266214
# and set `CLANG_CUDA_COMPILER_PATH` to define the toolchain so that we can
267215
# use Clang for the C++ targets and NVCC to build CUDA targets.
268216
build:ci_linux_aarch64_cuda --config=ci_linux_aarch64_base
269-
build:ci_linux_aarch64_cuda --config=local_cuda_base
217+
build:ci_linux_aarch64_cuda --config=cuda --config=build_cuda_with_nvcc
270218
build:ci_linux_aarch64_cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
271219

272220
# Mac x86 CI configs
273-
build:ci_darwin_x86_64 --config=local_darwin_x86_64
274-
# Mac CI builds read and push cache to/from GCS bucket.
221+
build:ci_darwin_x86_64 --macos_minimum_os=10.14
275222
build:ci_darwin_x86_64 --config=macos_cache_push
223+
build:ci_darwin_x86_64 --verbose_failures=true
276224

277225
# Mac Arm64 CI configs
278-
build:ci_darwin_arm64 --config=local_darwin_arm64
279-
# CI builds read and push cache to/from GCS bucket.
226+
build:ci_darwin_arm64 --macos_minimum_os=11.0
280227
build:ci_darwin_arm64 --config=macos_cache_push
228+
build:ci_darwin_arm64 --verbose_failures=true
281229

282230
# Windows x86 CI configs
283-
build:ci_windows_amd64 --config=local_windows_amd64
284-
build:ci_windows_amd64 --config=clang
285-
# Set the toolchains
231+
build:ci_windows_amd64 --config=avx_windows
232+
build:ci_windows_amd64 --compiler=clang-cl --config=clang --verbose_failures=true
286233
build:ci_windows_amd64 --crosstool_top="@xla//tools/toolchains/win/20240424:toolchain"
287234
build:ci_windows_amd64 --extra_toolchains="@xla//tools/toolchains/win/20240424:cc-toolchain-x64_windows-clang-cl"
288-
build:ci_windows_amd64 --compiler=clang-cl
289-
build:ci_windows_amd64 --linkopt=/FORCE:MULTIPLE
290-
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE
235+
build:ci_windows_amd64 --host_linkopt=/FORCE:MULTIPLE --linkopt=/FORCE:MULTIPLE
291236

292237
# #############################################################################
293238
# RBE config options below. These inherit the CI configs above and set the
@@ -333,9 +278,6 @@ build:rbe_linux_x86_64_base --host_platform="@ubuntu20.04-clang_manylinux2014-cu
333278
build:rbe_linux_x86_64_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
334279
build:rbe_linux_x86_64_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform"
335280

336-
# Python config is the same across all containers because the binary is the same
337-
build:rbe_linux_x86_64_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python"
338-
339281
build:rbe_linux_x86_64 --config=rbe_linux_x86_64_base
340282
build:rbe_linux_x86_64 --config=ci_linux_x86_64
341283

@@ -365,7 +307,7 @@ build:rbe_windows_amd64 --config=ci_windows_amd64
365307

366308
# #############################################################################
367309
# Cross-compile config options below. Native RBE support does not exist for
368-
# Linux Aarch64 and Mac x86. So, we use the cross-compile toolchain to build
310+
# Linux Aarch64 and Mac x86. So, we use a cross-compile toolchain to build
369311
# targets for Linux Aarch64 and Mac x86 on the Linux x86 RBE pool.
370312
# #############################################################################
371313
# Set execution platform to Linux x86
@@ -415,38 +357,6 @@ build:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings
415357
build:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64
416358
build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base
417359

418-
# #############################################################################
419-
# Test specific config options below. These are used when `bazel test` is run.
420-
# #############################################################################
421-
test --test_output=errors
422-
423-
# Common configs for running GPU tests.
424-
test:gpu --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
425-
426-
# Non-multiaccelerator tests with one GPU apiece. These tests are run on RBE
427-
# and locally.
428-
test:non_multiaccelerator --config=gpu
429-
test:non_multiaccelerator --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow
430-
test:non_multiaccelerator --test_tag_filters=-multiaccelerator
431-
432-
# Configs for running non-multiaccelerator tests locally
433-
test:non_multiaccelerator_local --config=non_multiaccelerator
434-
# Disable building jaxlib. Instead we depend on the local wheel.
435-
test:non_multiaccelerator_local --//jax:build_jaxlib=false
436-
437-
# `JAX_ACCELERATOR_COUNT` needs to match the number of GPUs in the VM.
438-
test:non_multiaccelerator_local --test_env=JAX_TESTS_PER_ACCELERATOR=12 --test_env=JAX_ACCELERATOR_COUNT=4
439-
440-
# The product of the `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
441-
# should match the VM's CPU core count (set in `--local_test_jobs`).
442-
test:non_multiaccelerator_local --local_test_jobs=48
443-
444-
# Multiaccelerator tests with all GPUs. These tests are only run locally
445-
# Disable building jaxlib. Instead we depend on the local wheel.
446-
test:multiaccelerator_local --config=gpu
447-
test:multiaccelerator_local --//jax:build_jaxlib=false
448-
test:multiaccelerator_local --jobs=8 --test_tag_filters=multiaccelerator
449-
450360
#############################################################################
451361
# Some configs to make getting some forms of debug builds. In general, the
452362
# codebase is only regularly built with optimizations. Use 'debug_symbols' to

.github/workflows/ci-build.yaml

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,27 +37,31 @@ jobs:
3737
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
3838

3939
build:
40-
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
41-
runs-on: ${{ matrix.os }}
40+
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ubuntu-20.04, x64=${{ matrix.enable-x64}})"
41+
runs-on: linux-x86-n2-32
42+
container:
43+
image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
4244
timeout-minutes: 60
4345
strategy:
4446
matrix:
4547
# Test the oldest and newest supported Python versions here.
4648
include:
4749
- name-prefix: "with 3.10"
4850
python-version: "3.10"
49-
os: ubuntu-20.04-16core
5051
enable-x64: 1
5152
prng-upgrade: 1
5253
num_generated_cases: 1
53-
- name-prefix: "with 3.12"
54-
python-version: "3.12"
55-
os: ubuntu-20.04-16core
54+
- name-prefix: "with 3.13"
55+
python-version: "3.13"
5656
enable-x64: 0
5757
prng-upgrade: 0
5858
num_generated_cases: 1
5959
steps:
6060
- uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1
61+
- name: Image Setup
62+
run: |
63+
apt update
64+
apt install -y libssl-dev
6165
- name: Set up Python ${{ matrix.python-version }}
6266
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0
6367
with:
@@ -68,7 +72,7 @@ jobs:
6872
python -m pip install --upgrade pip wheel
6973
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
7074
- name: pip cache
71-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
75+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
7276
with:
7377
path: ${{ steps.pip-cache.outputs.dir }}
7478
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -115,7 +119,7 @@ jobs:
115119
python -m pip install --upgrade pip wheel
116120
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
117121
- name: pip cache
118-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
122+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
119123
with:
120124
path: ${{ steps.pip-cache.outputs.dir }}
121125
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -152,7 +156,7 @@ jobs:
152156
python -m pip install --upgrade pip wheel
153157
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
154158
- name: pip cache
155-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
159+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
156160
with:
157161
path: ${{ steps.pip-cache.outputs.dir }}
158162
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -188,7 +192,7 @@ jobs:
188192
python -m pip install --upgrade pip wheel
189193
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
190194
- name: pip cache
191-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
195+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
192196
with:
193197
path: ${{ steps.pip-cache.outputs.dir }}
194198
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -227,7 +231,7 @@ jobs:
227231
python -m pip install --upgrade pip wheel
228232
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
229233
- name: pip cache
230-
uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0
234+
uses: actions/cache@3624ceb22c1c5a301c8db4169662070a689d9ea8 # v4.1.1
231235
with:
232236
path: ${{ steps.pip-cache.outputs.dir }}
233237
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,23 @@ jobs:
5050
pip install -U -r build/collect-profile-requirements.txt
5151
- name: Install JAX
5252
run: |
53-
pip uninstall -y jax jaxlib libtpu-nightly
53+
pip uninstall -y jax jaxlib libtpu
5454
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
5555
pip install .[tpu] \
5656
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5757
5858
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
5959
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60-
pip install --pre libtpu-nightly \
60+
pip install --pre libtpu \
6161
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6262
pip install requests
6363
6464
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
6565
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
66+
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
6667
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6768
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6869
pip install requests
69-
7070
else
7171
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7272
exit 1

0 commit comments

Comments
 (0)