Skip to content

Commit f06c67e

Browse files
committed
update docker scripts
1 parent bfef9f5 commit f06c67e

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

ci/envs/docker.env

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,41 @@
1515
# This file contains all the docker specifc envs that are needed by the
1616
# ci/utilities/run_docker_container.sh script.
1717

18-
# Inherit default JAXCI environment variables.
19-
source ci/envs/default.env
20-
2118
os=$(uname -s | awk '{print tolower($0)}')
2219
arch=$(uname -m)
2320

24-
# TODO: Set GPU Docker args and GPU Docker images
25-
# Linux x86 specifc settings
21+
# The path to the JAX git repository.
22+
export JAXCI_JAX_GIT_DIR=$(pwd)
23+
24+
export JAXCI_DOCKER_WORK_DIR="/jax"
25+
export JAXCI_DOCKER_ARGS=""
26+
27+
# Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and Bazel tests
2628
if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
2729
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
2830
fi
2931

30-
# Linux Aarch64 specifc settings
32+
# Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and Bazel tests
3133
if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
32-
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest"
34+
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest"
3335
fi
3436

35-
# Windows specific settings
37+
# Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel tests
3638
if [[ $os =~ "msys_nt" ]]; then
37-
export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
38-
fi
39+
export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest"
40+
fi
41+
42+
# Uncomment the following lines if you want to run the GPU tests with Pytest.
43+
# Note that GPU Pytests, as a prequisite, require that the following JAX artifacts be
44+
# present in the $JAXCI_OUTPUT_DIR: jaxlib, jax-cuda-plugin, jax-cuda-pjrt. If you don't
45+
# have these wheels stored there, either build them from source via ci/build_artifacts.sh or
46+
# download them from PyPI into that folder.
47+
#
48+
# Linux x86 image for running Pytest GPU tests
49+
# if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
50+
# # Choose one of: 12.3, 12.1
51+
# export JAXCI_DOCKER_CUDA_VERSION=${JAXCI_DOCKER_CUDA_VERSION:-12.3}
52+
# export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython"
53+
#
54+
# export JAXCI_DOCKER_ARGS="--gpus all --shm-size=16g"
55+
# fi

ci/utilities/run_docker_container.sh

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,20 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ==============================================================================
16-
# Set up the Docker container and start it for JAX CI jobs.
16+
# Sets up a Docker container for JAX CI.
17+
18+
# This script creates and starts a Docker container named "jax" for internal
19+
# JAX CI jobs. These jobs primarily handle building and publishing JAX artifacts
20+
# to PyPI and/or GCS.
21+
22+
# Note: GitHub Actions workflows do not utilize this script, as they leverage
23+
# built-in containerization features to run jobs within a container. However,
24+
# they use the same Docker image to maintain consistency. This script also helps
25+
# ensure that local build environments mirror the behavior of CI builds.
26+
# Usage:
27+
# ./ci/utilities/run_docker_container.sh
28+
# docker exec -it jax <build-script>
29+
# E.g: docker exec -it jax ./ci/build_artifacts.sh jaxlib
1730
#
1831
# -e: abort script if one command fails
1932
# -u: error if undefined variable used
@@ -46,13 +59,13 @@ if ! docker container inspect jax >/dev/null 2>&1 ; then
4659
JAXCI_DOCKER_ARGS="$JAXCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud"
4760
fi
4861

49-
# Start the container. `user_set_jaxci_envs` is read after `jax_ci_envs` to
50-
# allow the user to override any environment variables set by JAXCI_ENV_FILE.
62+
# Start the container.
5163
docker run $JAXCI_DOCKER_ARGS --name jax \
52-
-w "$JAXCI_DOCKER_WORK_DIR" -itd --rm \
53-
-v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \
54-
"$JAXCI_DOCKER_IMAGE" \
55-
bash
64+
--env-file <(env | grep JAXCI_) \
65+
-w "$JAXCI_DOCKER_WORK_DIR" -itd --rm \
66+
-v "$JAXCI_JAX_GIT_DIR:$JAXCI_DOCKER_WORK_DIR" \
67+
"$JAXCI_DOCKER_IMAGE" \
68+
bash
5669

5770
if [[ "$(uname -s)" =~ "MSYS_NT" ]]; then
5871
# Allow requests from the container.

0 commit comments

Comments
 (0)