Skip to content

Commit c4130cd

Browse files
committed
detect bazelrc config in the build artifacts script
1 parent 784c399 commit c4130cd

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

build/build.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,6 @@ async def main():
327327
add_global_arguments(build_artifact_parser)
328328

329329
arch = platform.machine()
330-
# Switch to lower case to match the case for the "ci_"/"rbe_" configs in the
331-
# .bazelrc.
332330
os_name = platform.system().lower()
333331

334332
args = parser.parse_args()

ci/build_artifacts.sh

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,26 @@ source ci/envs/default.env
4141
# Set up the build environment.
4242
source "ci/utilities/setup_build_environment.sh"
4343

44-
os=$(uname -s | awk '{print tolower($0)}')
4544
allowed_artifacts=("jax" "jaxlib" "jax-cuda-plugin" "jax-cuda-pjrt")
4645

46+
os=$(uname -s | awk '{print tolower($0)}')
47+
arch=$(uname -m)
48+
49+
# Adjust the values when running on Windows x86 to match the case in
50+
# .bazelrc
51+
if [[ $os =~ "msys_nt" ]] && [[ $arch == "x86_64" ]]; then
52+
os="windows"
53+
arch="amd64"
54+
fi
55+
56+
# Use "rbe_" config for Linux x86/Windows and "ci_" for other platforms
57+
bazelrc_config="$os_$arch"
58+
if ( [[ "$os" == "linux" ]] && [[ "$arch" == "x86_64" ]] ) || [[ "$os" == "windows" ]]; then
59+
bazelrc_config="rbe_$bazelrc_config"
60+
else
61+
bazelrc_config="ci_$bazelrc_config"
62+
fi
63+
4764
for artifact in "${artifacts[@]}"; do
4865

4966
if [[ " ${allowed_artifacts[@]} " =~ " ${artifact} " ]]; then
@@ -54,17 +71,17 @@ for artifact in "${artifacts[@]}"; do
5471

5572
# Build the jaxlib CPU artifact
5673
if [[ "$artifact" == "jaxlib" ]]; then
57-
python build/build.py build_artifacts --wheel_list="jaxlib" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
74+
python build/build.py build --wheels="jaxlib" --bazel_build_options=--config=$bazelrc_config --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
5875
fi
5976

6077
# Build the jax-cuda-plugin artifact
6178
if [[ "$artifact" == "jax-cuda-plugin" ]]; then
62-
python build/build.py build_artifacts --wheel_list="jax-cuda-plugin" --use_ci_bazelrc_flags --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
79+
python build/build.py build --wheels="jax-cuda-plugin" --bazel_build_options=--config="${bazelrc_config}_cuda" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
6380
fi
6481

6582
# Build the jax-cuda-pjrt artifact
6683
if [[ "$artifact" == "jax-cuda-pjrt" ]]; then
67-
python build/build.py build_artifacts --wheel_list="jax-cuda-pjrt" --use_ci_bazelrc_flags --verbose
84+
python build/build.py build --wheels="jax-cuda-pjrt" --bazel_build_options=--config="${bazelrc_config}_cuda" --verbose
6885
fi
6986

7087
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we

0 commit comments

Comments
 (0)