@@ -41,9 +41,26 @@ source ci/envs/default.env
4141# Set up the build environment.
4242source " ci/utilities/setup_build_environment.sh"
4343
44- os=$( uname -s | awk ' {print tolower($0)}' )
4544allowed_artifacts=(" jax" " jaxlib" " jax-cuda-plugin" " jax-cuda-pjrt" )
4645
46+ os=$( uname -s | awk ' {print tolower($0)}' )
47+ arch=$( uname -m)
48+
49+ # Adjust the values when running on Windows x86 to match the case in
50+ # .bazelrc
51+ if [[ $os =~ " msys_nt" ]] && [[ $arch == " x86_64" ]]; then
52+ os=" windows"
53+ arch=" amd64"
54+ fi
55+
56+ # Use "rbe_" config for Linux x86/Windows and "ci_" for other platforms
57+ bazelrc_config=" $os_$arch "
58+ if ( [[ " $os " == " linux" ]] && [[ " $arch " == " x86_64" ]] ) || [[ " $os " == " windows" ]]; then
59+ bazelrc_config=" rbe_$bazelrc_config "
60+ else
61+ bazelrc_config=" ci_$bazelrc_config "
62+ fi
63+
4764for artifact in " ${artifacts[@]} " ; do
4865
4966 if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
@@ -54,17 +71,17 @@ for artifact in "${artifacts[@]}"; do
5471
5572 # Build the jaxlib CPU artifact
5673 if [[ " $artifact " == " jaxlib" ]]; then
57- python build/build.py build_artifacts --wheel_list =" jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
74+ python build/build.py build --wheels =" jaxlib" --bazel_build_options=--config= $bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
5875 fi
5976
6077 # Build the jax-cuda-plugin artifact
6178 if [[ " $artifact " == " jax-cuda-plugin" ]]; then
62- python build/build.py build_artifacts --wheel_list =" jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
79+ python build/build.py build --wheels =" jax-cuda-plugin" --bazel_build_options=--config= " ${bazelrc_config} _cuda " --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
6380 fi
6481
6582 # Build the jax-cuda-pjrt artifact
6683 if [[ " $artifact " == " jax-cuda-pjrt" ]]; then
67- python build/build.py build_artifacts --wheel_list =" jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
84+ python build/build.py build --wheels =" jax-cuda-pjrt" --bazel_build_options=--config= " ${bazelrc_config} _cuda " --verbose
6885 fi
6986
7087 # If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
0 commit comments