Skip to content

Commit 42253cc

Browse files
committed
update build.py
1 parent 196c3de commit 42253cc

File tree

1 file changed

+29
-20
lines changed

1 file changed

+29
-20
lines changed

build/build.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
`python build/build.py requirements_update`
5555
"""
5656

57-
# Define the build target for each artifact.
58-
ARTIFACT_BUILD_TARGET_DICT = {
57+
# Define the build target for each wheel.
58+
WHEEL_BUILD_TARGET_DICT = {
5959
"jaxlib": "//jaxlib/tools:build_wheel",
6060
"jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel",
6161
"jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel",
@@ -143,10 +143,11 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
143143
type=str,
144144
default="jaxlib",
145145
help=
146-
f"""
147-
A comma separated list of JAX artifacts to build. E.g: --wheels="jaxlib",
146+
"""
147+
A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib",
148148
--wheels="jaxlib,jax-cuda-plugin", etc.
149-
Valid options are: {','.join(ARTIFACT_BUILD_TARGET_DICT.keys())}
149+
Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,
150+
jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt
150151
""",
151152
)
152153

@@ -362,8 +363,7 @@ async def main():
362363
f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}"
363364
)
364365

365-
# Enable color in the Bazel output and verbose failures.
366-
bazel_command_base.append("--color=yes")
366+
# Enable verbose failures.
367367
bazel_command_base.append("--verbose_failures=true")
368368

369369
# Requirements update subcommand execution
@@ -377,7 +377,7 @@ async def main():
377377
requirements_command.append(option)
378378

379379
if args.nightly_update:
380-
logging.debug(
380+
logging.info(
381381
"--nightly_update is set. Bazel will run"
382382
" //build:requirements_nightly.update"
383383
)
@@ -414,8 +414,16 @@ async def main():
414414

415415
# Wheel build command execution
416416
for wheel in args.wheels.split(","):
417-
if wheel not in ARTIFACT_BUILD_TARGET_DICT.keys():
418-
logging.error("Incorrect wheel name provided: %s, valid choices are: %s", wheel, ",".join(ARTIFACT_BUILD_TARGET_DICT.keys()))
417+
# Allow CUDA/ROCm wheels without the "jax-" prefix.
418+
if ("plugin" in wheel or "pjrt" in wheel) and "jax" not in wheel:
419+
wheel = "jax-" + wheel
420+
421+
if wheel not in WHEEL_BUILD_TARGET_DICT.keys():
422+
logging.error(
423+
"Incorrect wheel name provided, valid choices are jaxlib,"
424+
" jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
425+
" jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
426+
)
419427
sys.exit(1)
420428

421429
wheel_build_command = copy.deepcopy(bazel_command_base)
@@ -518,18 +526,19 @@ async def main():
518526
for option in args.bazel_options:
519527
wheel_build_command.append(option)
520528

529+
with open(".jax_configure.bazelrc", "w") as f:
530+
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list())
531+
if not jax_configure_options:
532+
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
533+
sys.exit(1)
534+
f.write(jax_configure_options)
535+
logging.info("Bazel options written to .jax_configure.bazelrc")
536+
521537
if args.configure_only:
522-
with open(".jax_configure.bazelrc", "w") as f:
523-
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list())
524-
if not jax_configure_options:
525-
logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.")
526-
sys.exit(1)
527-
f.write(jax_configure_options)
528-
logging.info("Bazel options written to .jax_configure.bazelrc")
529-
logging.info("--configure_only is set so not running any Bazel commands.")
538+
logging.info("--configure_only is set so not running any Bazel commands.")
530539
else:
531540
# Append the build target to the Bazel command.
532-
build_target = ARTIFACT_BUILD_TARGET_DICT[wheel]
541+
build_target = WHEEL_BUILD_TARGET_DICT[wheel]
533542
wheel_build_command.append(build_target)
534543

535544
wheel_build_command.append("--")
@@ -538,7 +547,7 @@ async def main():
538547
logger.debug("Artifacts output directory: %s", output_path)
539548

540549
if args.editable:
541-
logger.debug("Building an editable build")
550+
logger.info("Building an editable build")
542551
output_path = os.path.join(output_path, wheel)
543552
wheel_build_command.append("--editable")
544553

0 commit comments

Comments
 (0)