Skip to content

Commit 43aef52

Browse files
committed
Re-factor how docker envs are set and docker container gets set up
1 parent b15621e commit 43aef52

File tree

12 files changed

+49
-110
lines changed

12 files changed

+49
-110
lines changed

ci/envs/build_artifacts/jax-cuda-pjrt.env

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,4 @@ source ci/envs/default.env
1919
export JAXCI_BUILD_PJRT="1"
2020

2121
# Enable wheel audit to check for manylinux compliance.
22-
export JAXCI_RUN_AUDITWHEEL="1"
23-
24-
os=$(uname -s | awk '{print tolower($0)}')
25-
arch=$(uname -m)
26-
27-
# Linux x86 specifc settings
28-
if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
29-
# Note Python version of the container does not matter for Bazel builds and
30-
# Bazel tests. JAX supports hermetic Python and thus the actual Python version
31-
# of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`.
32-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12"
33-
fi
34-
35-
# Linux Aarch64 specifc settings
36-
if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
37-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python"
38-
fi
22+
export JAXCI_RUN_AUDITWHEEL="1"

ci/envs/build_artifacts/jax-cuda-plugin.env

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,4 @@ source ci/envs/default.env
1919
export JAXCI_BUILD_PLUGIN="1"
2020

2121
# Enable wheel audit to check for manylinux compliance.
22-
export JAXCI_RUN_AUDITWHEEL="1"
23-
24-
os=$(uname -s | awk '{print tolower($0)}')
25-
arch=$(uname -m)
26-
27-
# Linux x86 specifc settings
28-
if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
29-
# Note Python version of the container does not matter for Bazel builds and
30-
# Bazel tests. JAX supports hermetic Python and thus the actual Python version
31-
# of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`.
32-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12"
33-
fi
34-
35-
# Linux Aarch64 specifc settings
36-
if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
37-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:tf-2-18-multi-python"
38-
fi
22+
export JAXCI_RUN_AUDITWHEEL="1"

ci/envs/build_artifacts/jax.env

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,4 @@
1616
source ci/envs/default.env
1717

1818
# Build JAX artifact.
19-
export JAXCI_BUILD_JAX="1"
20-
21-
# Note Python version of the container does not matter as `jax` is a pure
22-
# Python package.
23-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12"
19+
export JAXCI_BUILD_JAX="1"

ci/envs/build_artifacts/jaxlib.env

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,3 @@ source ci/envs/default.env
1818
# Enable jaxlib build.
1919
export JAXCI_BUILD_JAXLIB="1"
2020

21-
os=$(uname -s | awk '{print tolower($0)}')
22-
arch=$(uname -m)
23-
24-
# Linux x86 specifc settings
25-
if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
26-
# Enable wheel audit to check for manylinux compliance.
27-
export JAXCI_RUN_AUDITWHEEL=1
28-
29-
# Note Python version of the container does not matter for Bazel builds and
30-
# Bazel tests. JAX supports hermetic Python and thus the actual Python version
31-
# of the artifact is controlled by the value set in `HERMETIC_PYTHON_VERSION`.
32-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:670606426-python3.12"
33-
fi
34-
35-
# Linux Aarch64 specifc settings
36-
if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
37-
# Enable wheel audit to check for manylinux compliance.
38-
export JAXCI_RUN_AUDITWHEEL=1
39-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-"
40-
fi
41-
42-
# Windows specific settings
43-
if [[ $os =~ "msys_nt" ]]; then
44-
export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
45-
fi
46-
47-
# Mac specific settings
48-
if [[ $os == "macos" ]]; then
49-
# Mac builds do not run in Docker.
50-
export JAXCI_RUN_DOCKER_CONTAINER=0
51-
fi

ci/envs/docker.env

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# This file contains all the docker specifc envs that are needed by the
16+
# ci/utilities/run_docker_container.sh script.
17+
os=$(uname -s | awk '{print tolower($0)}')
18+
arch=$(uname -m)
19+
20+
# TODO: Set GPU Docker args and GPU Docker images
21+
# Linux x86 specifc settings
22+
if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
23+
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
24+
fi
25+
26+
# Linux Aarch64 specifc settings
27+
if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
28+
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest"
29+
fi
30+
31+
# Windows specific settings
32+
if [[ $os =~ "msys_nt" ]]; then
33+
export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
34+
fi
35+
36+
# Mac specific settings
37+
if [[ $os == "macos" ]]; then
38+
# Mac builds do not run in Docker.
39+
export JAXCI_RUN_DOCKER_CONTAINER=0
40+
fi

ci/envs/run_tests/bazel_gpu_non_rbe.env

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,3 @@ source ci/envs/default.env
1717

1818
# Enable local Bazel GPU tests
1919
export JAXCI_RUN_BAZEL_TEST_GPU_LOCAL=1
20-
21-
# Only Linux x86 runs local GPU tests at the moment.
22-
export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython"
23-
export JAXCI_DOCKER_ARGS="--shm-size=16g --gpus all"

ci/envs/run_tests/pytest_cpu.env

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,13 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
# Inherit default JAXCI environment variables.
16-
source ci/envs/default
16+
source ci/envs/default.env
1717

1818
# Enable CPU Pytests
1919
export JAXCI_RUN_PYTEST_CPU=1
2020

2121
# Install jaxlib wheel locally.
2222
export JAXCI_INSTALL_WHEELS_LOCALLY=1
2323

24-
# Disable x64 mode
25-
export JAX_ENABLE_X64=0
26-
2724
# Clone XLA at HEAD.
2825
export JAXCI_CLONE_MAIN_XLA=1
29-
30-
export TF_CPP_MIN_LOG_LEVEL=0

ci/envs/run_tests/pytest_gpu.env

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,5 @@ export JAXCI_RUN_PYTEST_GPU=1
2121
# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels locally.
2222
export JAXCI_INSTALL_WHEELS_LOCALLY=1
2323

24-
# Only Linux x86 runs local GPU tests at the moment.
25-
export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython"
26-
export JAXCI_DOCKER_ARGS="--shm-size=16g --gpus all"
27-
28-
# TODO(srnitin): Figure out where this gets used
29-
export JAX_CUDA_VERSION=12
30-
export JAX_CUDA_FULL_VERSION=12.3
31-
export JAX_DOCKER_CUDA_FULL_VERSION=12.1
32-
export JAX_CUDNN_VERSION=9.1
33-
export JAX_CUDA_PLUGIN='True'
34-
35-
# Disable x64 mode
36-
export JAX_ENABLE_X64=0
37-
3824
# Clone XLA at HEAD.
3925
export JAXCI_CLONE_MAIN_XLA=1

ci/envs/run_tests/pytest_tpu.env

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,5 @@ export JAXCI_RUN_PYTEST_TPU=1
2121

2222
export JAXCI_TPU_CORES=8
2323

24-
# Disable x64 mode
25-
export JAX_ENABLE_X64=0
26-
2724
# Clone XLA at HEAD.
2825
export JAXCI_CLONE_MAIN_XLA=1

ci/run_pytest.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ source "ci/utilities/setup_jaxci_envs.sh" "$1"
1818
# Set up the build environment.
1919
source "ci/utilities/setup_build_environment.sh"
2020

21+
export JAX_SKIP_SLOW_TESTS=true
22+
export JAX_ENABLE_X64=0
23+
2124
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
2225

2326
if [[ $JAXCI_RUN_PYTEST_CPU == 1 ]]; then

0 commit comments

Comments
 (0)