1515# This file contains all the docker specifc envs that are needed by the
1616# ci/utilities/run_docker_container.sh script.
1717
18- # Inherit default JAXCI environment variables.
19- source ci/envs/default.env
20-
2118os = $( uname -s | awk ' {print tolower($0) }')
2219arch = $( uname -m)
2320
24- # TODO: Set GPU Docker args and GPU Docker images
25- # Linux x86 specifc settings
21+ # The path to the JAX git repository.
22+ export JAXCI_JAX_GIT_DIR = $( pwd)
23+
24+ export JAXCI_DOCKER_WORK_DIR = " /jax"
25+ export JAXCI_DOCKER_ARGS = " "
26+
27+ # Linux x86 image for building JAX artifacts, running Pytests CPU/TPU tests, and Bazel tests
2628if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
2729 export JAXCI_DOCKER_IMAGE = " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
2830fi
2931
30- # Linux Aarch64 specifc settings
32+ # Linux Aarch64 image for building JAX artifacts, running Pytests CPU tests, and Bazel tests
3133if [[ $os == "linux" ]] && [[ $arch == "aarch64" ]]; then
32- export JAXCI_DOCKER_IMAGE = " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container :latest"
34+ export JAXCI_DOCKER_IMAGE = " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64 :latest"
3335fi
3436
35- # Windows specific settings
37+ # Windows image for building JAX artifacts, running Pytests CPU tests, and Bazel tests
3638if [[ $os =~ "msys_nt" ]]; then
37- export JAXCI_DOCKER_IMAGE = " gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc"
38- fi
39+ export JAXCI_DOCKER_IMAGE = " us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest"
40+ fi
41+
42+ # Uncomment the following lines if you want to run the GPU tests with Pytest.
43+ # Note that GPU Pytests, as a prequisite, require that the following JAX artifacts be
44+ # present in the $JAXCI_OUTPUT_DIR: jaxlib, jax-cuda-plugin, jax-cuda-pjrt. If you don't
45+ # have these wheels stored there, either build them from source via ci/build_artifacts.sh or
46+ # download them from PyPI into that folder.
47+ #
48+ # Linux x86 image for running Pytest GPU tests
49+ # if [[ $os == "linux" ]] && [[ $arch == "x86_64" ]]; then
50+ # # Choose one of: 12.3, 12.1
51+ # export JAXCI_DOCKER_CUDA_VERSION=${JAXCI_DOCKER_CUDA_VERSION:-12.3}
52+ # export JAXCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/nosla-cuda${JAXCI_DOCKER_CUDA_VERSION}-cudnn9.1-ubuntu20.04-manylinux2014-multipython"
53+ #
54+ # export JAXCI_DOCKER_ARGS="--gpus all --shm-size=16g"
55+ # fi
0 commit comments