1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515# ==============================================================================
16+ <<< <<< < HEAD
1617# Runs Pyest CPU tests. Requires all jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
18+ =======
19+ # Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
20+ >>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
1721# wheels to be present inside $JAXCI_OUTPUT_DIR (../dist)
1822#
1923# -e: abort script if one command fails
2327# -o allexport: export all functions and variables to be available to subscripts
2428set -exu -o history -o allexport
2529
30+ <<< <<< < HEAD
2631# Inherit default JAXCI environment variables.
2732source ci/envs/default.env
2833
2934# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels on the system.
35+ =======
36+ # Source default JAXCI environment variables.
37+ source ci/envs/default.env
38+
39+ # Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the
40+ # $JAXCI_OUTPUT_DIR directory on the system.
41+ >>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
3042echo " Installing wheels locally..."
3143source ./ci/utilities/install_wheels_locally.sh
3244
3345# Set up the build environment.
3446source " ci/utilities/setup_build_environment.sh"
3547
48+ <<< <<< < HEAD
3649export PY_COLORS=1
3750export JAX_SKIP_SLOW_TESTS=true
3851
@@ -46,6 +59,28 @@ echo "Running GPU tests..."
4659export XLA_PYTHON_CLIENT_ALLOCATOR=platform
4760export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
4861" $JAXCI_PYTHON " -m pytest -n 8 --tb=short --maxfail=20 \
62+ =======
63+ " $JAXCI_PYTHON " -c " import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
64+
65+ nvidia-smi
66+
67+ # Set up all test environment variables
68+ export PY_COLORS=1
69+ export JAX_SKIP_SLOW_TESTS=true
70+ export NCCL_DEBUG=WARN
71+ export TF_CPP_MIN_LOG_LEVEL=0
72+
73+ # Set the number of processes to run to be 4x the number of GPUs.
74+ export gpu_count=$( nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
75+ export num_processes=` expr 4 \* $gpu_count `
76+
77+ export XLA_PYTHON_CLIENT_ALLOCATOR=platform
78+ export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
79+ # End of test environment variable setup
80+
81+ echo " Running GPU tests..."
82+ " $JAXCI_PYTHON " -m pytest -n $num_processes --tb=short --maxfail=20 \
83+ >>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
4984tests examples \
5085--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
5186--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
0 commit comments