Skip to content

Commit 1372669

Browse files
nitins17Google-ML-Automation
authored and committed
Add new CI script to run Bazel GPU (non-RBE) jobs
This commit adds the CI script needed for running Bazel GPU (non-RBE) tests. These run two Bazel commands: single-accelerator tests with one GPU apiece, and multi-accelerator tests with all GPUs. PiperOrigin-RevId: 700523594
1 parent c6866d0 commit 1372669

File tree

4 files changed

+105
-4
lines changed

4 files changed

+105
-4
lines changed

ci/envs/default.env

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,18 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist"
4242
# When enabled, artifacts will be built with RBE. Requires gcloud authentication
4343
# and only certain platforms support RBE. Therefore, this flag is enabled only
4444
# for CI builds where RBE is supported.
45-
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
45+
export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0}
46+
47+
# #############################################################################
48+
# Test script specific environment variables.
49+
# #############################################################################
50+
# The maximum number of tests to run per GPU when running single accelerator
51+
# tests with parallel execution with Bazel. The GPU limit is set because we
52+
# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we
53+
# use L4 machines which have 24GB of RAM but can be overridden if we use a
54+
# different GPU type.
55+
export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12}
56+
57+
# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override
58+
# this value in the GitHub Actions workflow files.
59+
export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0}

ci/run_bazel_test_cpu_rbe.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] )
5050
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
5151
--test_env=JAX_NUM_GENERATED_CASES=25 \
5252
--test_env=JAX_SKIP_SLOW_TESTS=true \
53-
--action_env=JAX_ENABLE_X64=0 \
53+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
5454
--test_output=errors \
5555
--color=yes \
5656
//tests:cpu_tests //tests:backend_independent_tests
@@ -61,7 +61,7 @@ else
6161
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
6262
--test_env=JAX_NUM_GENERATED_CASES=25 \
6363
--test_env=JAX_SKIP_SLOW_TESTS=true \
64-
--action_env=JAX_ENABLE_X64=0 \
64+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
6565
--test_output=errors \
6666
--color=yes \
6767
//tests:cpu_tests //tests:backend_independent_tests

ci/run_bazel_test_gpu_non_rbe.sh

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/bin/bash
2+
# Copyright 2024 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Run Bazel GPU tests without RBE. This runs two commands: single accelerator
17+
# tests with one GPU apiece, and multiaccelerator tests with all GPUs.
18+
# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored
19+
# inside the ../dist folder
20+
#
21+
# -e: abort script if one command fails
22+
# -u: error if undefined variable used
23+
# -x: log all commands
24+
# -o history: record shell history
25+
# -o allexport: export all functions and variables to be available to subscripts
26+
set -exu -o history -o allexport
27+
28+
# Source default JAXCI environment variables.
29+
source ci/envs/default.env
30+
31+
# Set up the build environment.
32+
source "ci/utilities/setup_build_environment.sh"
33+
34+
# Run Bazel GPU tests (single accelerator and multiaccelerator tests) directly
35+
# on the VM without RBE.
36+
nvidia-smi
37+
echo "Running single accelerator tests (without RBE)..."
38+
39+
# Set up test environment variables.
40+
export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
41+
export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU))
42+
export num_cpu_cores=$(nproc)
43+
44+
# num_test_jobs = min(gpu_count * max_tests_per_gpu, num_cpu_cores)
45+
if [[ $num_test_jobs -gt $num_cpu_cores ]]; then
46+
num_test_jobs=$num_cpu_cores
47+
fi
48+
# End of test environment variables setup.
49+
50+
# Runs single accelerator tests with one GPU apiece.
51+
# It appears --run_under needs an absolute path.
52+
# The product of `JAX_ACCELERATOR_COUNT` and `JAX_TESTS_PER_ACCELERATOR`
53+
# should match the VM's CPU core count (set in `--local_test_jobs`).
54+
bazel test --config=ci_linux_x86_64_cuda \
55+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
56+
--//jax:build_jaxlib=false \
57+
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
58+
--run_under "$(pwd)/build/parallel_accelerator_execute.sh" \
59+
--test_output=errors \
60+
--test_env=JAX_ACCELERATOR_COUNT=$gpu_count \
61+
--test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \
62+
--local_test_jobs=$num_test_jobs \
63+
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
64+
--test_tag_filters=-multiaccelerator \
65+
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
66+
--test_env=JAX_SKIP_SLOW_TESTS=true \
67+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
68+
--action_env=NCCL_DEBUG=WARN \
69+
--color=yes \
70+
//tests:gpu_tests //tests:backend_independent_tests \
71+
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests
72+
73+
echo "Running multi-accelerator tests (without RBE)..."
74+
# Runs multiaccelerator tests with all GPUs directly on the VM without RBE.
75+
bazel test --config=ci_linux_x86_64_cuda \
76+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
77+
--//jax:build_jaxlib=false \
78+
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \
79+
--test_output=errors \
80+
--jobs=8 \
81+
--test_tag_filters=multiaccelerator \
82+
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \
83+
--test_env=JAX_SKIP_SLOW_TESTS=true \
84+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
85+
--action_env=NCCL_DEBUG=WARN \
86+
--color=yes \
87+
//tests:gpu_tests //tests/pallas:gpu_tests

ci/run_bazel_test_gpu_rbe.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ bazel test --config=rbe_linux_x86_64_cuda \
4646
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \
4747
--test_tag_filters=-multiaccelerator \
4848
--test_env=JAX_SKIP_SLOW_TESTS=true \
49-
--action_env=JAX_ENABLE_X64=0 \
49+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
5050
--color=yes \
5151
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests

0 commit comments

Comments
 (0)