Skip to content

Commit 3b55033

Browse files
committed
update scripts
1 parent 36d04de commit 3b55033

File tree

11 files changed

+168
-22
lines changed

11 files changed

+168
-22
lines changed

ci/.bazelrc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,38 @@ build:cross_compile_darwin_x86_64 --platform_mappings=platform_mappings
415415
build:rbe_cross_compile_darwin_x86_64 --config=cross_compile_darwin_x86_64
416416
build:rbe_cross_compile_darwin_x86_64 --config=rbe_cross_compile_base
417417

418+
# #############################################################################
419+
# Test specific config options below. These are used when `bazel test` is run.
420+
# #############################################################################
421+
test --test_output=errors
422+
423+
# Common configs for for running GPU tests.
424+
test:gpu --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform
425+
426+
# Non-multiaccelerator tests with one GPU apiece. These tests are run on RBE
427+
# and locally.
428+
test:non_multiaccelerator --config=gpu
429+
test:non_multiaccelerator --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow
430+
test:non_multiaccelerator --test_tag_filters=-multiaccelerator
431+
432+
# Configs for running non-multiaccelerator tests locally
433+
test:non_multiaccelerator_local --config=non_multiaccelerator
434+
# Disable building jaxlib. Instead we depend on the local wheel.
435+
test:non_multiaccelerator_local --//jax:build_jaxlib=false
436+
437+
# `JAX_ACCELERATOR_COUNT` needs to match the number of GPUs in the VM.
438+
test:non_multiaccelerator_local --test_env=JAX_TESTS_PER_ACCELERATOR=12 --test_env=JAX_ACCELERATOR_COUNT=4
439+
440+
# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR`
441+
# should match the VM's CPU core count (set in `--local_test_jobs`).
442+
test:non_multiaccelerator_local --local_test_jobs=48
443+
444+
# Multiaccelerator tests with all GPUs. These tests are only run locally
445+
# Disable building jaxlib. Instead we depend on the local wheel.
446+
test:multiaccelerator_local --config=gpu
447+
test:multiaccelerator_local --//jax:build_jaxlib=false
448+
test:multiaccelerator_local --jobs=8 --test_tag_filters=multiaccelerator
449+
418450
#############################################################################
419451
# Some configs to make getting some forms of debug builds. In general, the
420452
# codebase is only regularly built with optimizations. Use 'debug_symbols' to

ci/envs/default

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,40 @@ export JAXCI_DOCKER_WORK_DIR="/jax"
6464
export JAXCI_DOCKER_IMAGE=""
6565
export JAXCI_DOCKER_ARGS=""
6666

67+
# #############################################################################
68+
# Test specific environment variables.
69+
# #############################################################################
70+
71+
# Used by envs inside ci/build_artifacts. When set to 1, we disable x64 mode
72+
# and clone XLA at HEAD.
73+
export JAXCI_SETUP_TEST_ENVIRONMENT=${JAXCI_SETUP_TEST_ENVIRONMENT:-0}
74+
75+
# Set when running tests locally where we need the wheels to be installed on
76+
# the system.
77+
export JAXCI_INSTALL_WHEELS_LOCALLY=0
78+
79+
# JAXCI_PYTHON is used to install the wheels locally. It needs to match the
80+
# version of the hermetic Python used by Bazel.
81+
export JAXCI_PYTHON=python${JAXCI_HERMETIC_PYTHON_VERSION}
82+
83+
# Bazel test environment variables.
84+
export JAXCI_RUN_BAZEL_TEST_CPU=0
85+
export JAXCI_RUN_BAZEL_TEST_GPU_LOCAL=0
86+
export JAXCI_RUN_BAZEL_TEST_GPU_RBE=0
87+
88+
# Pytest environment variables.
89+
export JAXCI_RUN_PYTEST_CPU=0
90+
export JAXCI_RUN_PYTEST_GPU=0
91+
export JAXCI_RUN_PYTEST_TPU=0
92+
export JAXCI_TPU_CORES=""
93+
94+
# If set to 1, the script will clone the main XLA repository at HEAD, set its
95+
# path in JAXCI_XLA_GIT_DIR and use it to build the artifacts or run the tests.
96+
export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0}
97+
98+
# Enable this globally across all builds.
99+
export JAX_SKIP_SLOW_TESTS=true
100+
67101
# #############################################################################
68102
# Variables that can be overridden by the user.
69103
# #############################################################################

ci/envs/run_tests/bazel_cpu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/envs/run_tests/bazel_gpu_local

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/envs/run_tests/bazel_gpu_rbe

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/envs/run_tests/pytest_cpu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/envs/run_tests/pytest_gpu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/envs/run_tests/pytest_tpu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
115
# Inherit default JAXCI environment variables.
216
source ci/envs/default
317

ci/run_bazel_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2024 JAX Authors. All Rights Reserved.
2+
# Copyright 2024 The JAX Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616
# Source JAXCI environment variables.
17-
source "ci/utilities/setup_envs.sh"
17+
source "ci/utilities/setup_envs.sh" "$1"
1818
# Set up the build environment.
1919
source "ci/utilities/setup_build_environment.sh"
2020

ci/run_pytest.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2024 JAX Authors. All Rights Reserved.
2+
# Copyright 2024 The JAX Authors.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616
# Source JAXCI environment variables.
17-
source "ci/utilities/setup_envs.sh"
17+
source "ci/utilities/setup_envs.sh" "$1"
1818
# Set up the build environment.
1919
source "ci/utilities/setup_build_environment.sh"
2020

@@ -40,13 +40,13 @@ fi
4040
if [[ $JAXCI_RUN_PYTEST_TPU == 1 ]]; then
4141
echo "Running TPU tests..."
4242
# Run single-accelerator tests in parallel
43-
export JAX_ENABLE_TPU_XDIST=true
43+
export JAX_ENABLE_TPU_XDIST=true
4444
check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
4545
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
4646
--maxfail=20 -m "not multiaccelerator" tests examples
47-
47+
4848
# Run Pallas printing tests, which need to run with I/O capturing disabled.
49-
export TPU_STDERR_LOG_LEVEL=0
49+
export TPU_STDERR_LOG_LEVEL=0
5050
check_if_to_run_in_docker "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
5151

5252
# Run multi-accelerator across all chips

0 commit comments

Comments
 (0)