Skip to content

Commit 240d284

Browse files
committed
Change how arch is detected on Windows machines
1 parent 2217b86 commit 240d284

File tree

2 files changed

+15
-17
lines changed

2 files changed

+15
-17
lines changed

build/build.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def add_build_cuda_with_clang_argument(parser: argparse.ArgumentParser):
110110
action="store_true",
111111
help="""
112112
Should CUDA code be compiled using Clang? The default behavior is to
113-
compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set, we always build
114-
CUDA with NVCC in CI builds.
113+
compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set, we
114+
always build CUDA with NVCC in CI builds.
115115
""",
116116
)
117117

@@ -252,8 +252,8 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
252252
type=str,
253253
default="",
254254
help="""
255-
Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags is set as we use
256-
a custom Clang toolchain in that case.
255+
Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags is
256+
set as we use a custom Clang toolchain in that case.
257257
""",
258258
)
259259

@@ -360,7 +360,9 @@ async def main():
360360
add_artifact_subcommand_global_arguments(rocm_pjrt_parser)
361361
add_global_arguments(rocm_pjrt_parser)
362362

363-
arch = platform.machine().lower()
363+
arch = platform.machine()
364+
# Switch to lower case to match the case for the "ci_"/"rbe_" configs in the
365+
# .bazelrc.
364366
os_name = platform.system().lower()
365367

366368
args = parser.parse_args()
@@ -437,7 +439,7 @@ async def main():
437439
# to Bazel to use as the C++ compiler. NVCC is used as the CUDA compiler
438440
# unless the user explicitly sets --config=build_cuda_with_clang.
439441
if args.use_ci_bazelrc_flags:
440-
bazelrc_config = utils.get_ci_bazelrc_config(os_name, arch, args.command)
442+
bazelrc_config = utils.get_ci_bazelrc_config(os_name, arch.lower(), args.command)
441443
logging.debug("Using --config=%s from .bazelrc", bazelrc_config)
442444
bazel_command.append(f"--config={bazelrc_config}")
443445
else:
@@ -477,7 +479,7 @@ async def main():
477479
"Using release cpu features: --config=avx_%s",
478480
"windows" if os_name == "windows" else "posix",
479481
)
480-
if arch == "x86_64":
482+
if arch in ["x86_64", "AMD64"]:
481483
bazel_command.append(
482484
"--config=avx_windows"
483485
if os_name == "windows"
@@ -547,9 +549,7 @@ async def main():
547549

548550
# If running on Windows, adjust the paths for compatibility.
549551
if os_name == "windows":
550-
output_path, target_cpu = utils.adjust_paths_for_windows(
551-
output_path, target_cpu
552-
)
552+
output_path = utils.adjust_paths_for_windows(output_path)
553553

554554
logger.debug("Artifacts output directory: %s", output_path)
555555

build/tools/utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
logger = logging.getLogger(__name__)
3030

31+
def is_windows():
32+
return sys.platform.startswith("win32")
33+
3134
# Bazel
3235
BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/"
3336
BazelPackage = collections.namedtuple(
@@ -227,16 +230,11 @@ def get_ci_bazelrc_config(os_name: str, arch: str, artifact: str):
227230
return bazelrc_config
228231

229232

230-
def adjust_paths_for_windows(output_dir: str, arch: str) -> tuple[str, str]:
233+
def adjust_paths_for_windows(output_dir: str) -> tuple[str, str]:
231234
"""Adjusts the paths to be compatible with Windows."""
232235
logger.debug("Adjusting paths for Windows...")
233236
output_dir = output_dir.replace("/", "\\")
234-
235-
# Change to upper case to match the case in
236-
# "jax/tools/build_utils.py" for Windows.
237-
arch = arch.upper()
238-
239-
return (output_dir, arch)
237+
return output_dir
240238

241239

242240
def get_githash():

0 commit comments

Comments
 (0)