5454 `python build/build.py requirements_update`
5555"""
5656
57- # Define the build target for each artifact .
58- ARTIFACT_BUILD_TARGET_DICT = {
57+ # Define the build target for each wheel .
58+ WHEEL_BUILD_TARGET_DICT = {
5959 "jaxlib" : "//jaxlib/tools:build_wheel" ,
6060 "jax-cuda-plugin" : "//jaxlib/tools:build_gpu_kernels_wheel" ,
6161 "jax-cuda-pjrt" : "//jaxlib/tools:build_gpu_plugin_wheel" ,
@@ -143,10 +143,11 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
143143 type = str ,
144144 default = "jaxlib" ,
145145 help =
146- f """
147- A comma separated list of JAX artifacts to build. E.g: --wheels="jaxlib",
146+ """
147+ A comma separated list of JAX wheels to build. E.g: --wheels="jaxlib",
148148 --wheels="jaxlib,jax-cuda-plugin", etc.
149- Valid options are: { ',' .join (ARTIFACT_BUILD_TARGET_DICT .keys ())}
149+ Valid options are: jaxlib, jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,
150+ jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt
150151 """ ,
151152 )
152153
@@ -362,8 +363,7 @@ async def main():
362363 f"--repo_env=HERMETIC_PYTHON_VERSION={ args .python_version } "
363364 )
364365
365- # Enable color in the Bazel output and verbose failures.
366- bazel_command_base .append ("--color=yes" )
366+ # Enable verbose failures.
367367 bazel_command_base .append ("--verbose_failures=true" )
368368
369369 # Requirements update subcommand execution
@@ -377,7 +377,7 @@ async def main():
377377 requirements_command .append (option )
378378
379379 if args .nightly_update :
380- logging .debug (
380+ logging .info (
381381 "--nightly_update is set. Bazel will run"
382382 " //build:requirements_nightly.update"
383383 )
@@ -414,8 +414,16 @@ async def main():
414414
415415 # Wheel build command execution
416416 for wheel in args .wheels .split ("," ):
417- if wheel not in ARTIFACT_BUILD_TARGET_DICT .keys ():
418- logging .error ("Incorrect wheel name provided: %s, valid choices are: %s" , wheel , "," .join (ARTIFACT_BUILD_TARGET_DICT .keys ()))
417+ # Allow CUDA/ROCm wheels without the "jax-" prefix.
418+ if ("plugin" in wheel or "pjrt" in wheel ) and "jax" not in wheel :
419+ wheel = "jax-" + wheel
420+
421+ if wheel not in WHEEL_BUILD_TARGET_DICT .keys ():
422+ logging .error (
423+ "Incorrect wheel name provided, valid choices are jaxlib,"
424+ " jax-cuda-plugin or cuda-plugin, jax-cuda-pjrt or cuda-pjrt,"
425+ " jax-rocm-plugin or rocm-plugin, jax-rocm-pjrt or rocm-pjrt"
426+ )
419427 sys .exit (1 )
420428
421429 wheel_build_command = copy .deepcopy (bazel_command_base )
@@ -518,18 +526,19 @@ async def main():
518526 for option in args .bazel_options :
519527 wheel_build_command .append (option )
520528
529+ with open (".jax_configure.bazelrc" , "w" ) as f :
530+ jax_configure_options = utils .get_jax_configure_bazel_options (wheel_build_command .get_command_as_list ())
531+ if not jax_configure_options :
532+ logging .error ("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting." )
533+ sys .exit (1 )
534+ f .write (jax_configure_options )
535+ logging .info ("Bazel options written to .jax_configure.bazelrc" )
536+
521537 if args .configure_only :
522- with open (".jax_configure.bazelrc" , "w" ) as f :
523- jax_configure_options = utils .get_jax_configure_bazel_options (wheel_build_command .get_command_as_list ())
524- if not jax_configure_options :
525- logging .error ("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting." )
526- sys .exit (1 )
527- f .write (jax_configure_options )
528- logging .info ("Bazel options written to .jax_configure.bazelrc" )
529- logging .info ("--configure_only is set so not running any Bazel commands." )
538+ logging .info ("--configure_only is set so not running any Bazel commands." )
530539 else :
531540 # Append the build target to the Bazel command.
532- build_target = ARTIFACT_BUILD_TARGET_DICT [wheel ]
541+ build_target = WHEEL_BUILD_TARGET_DICT [wheel ]
533542 wheel_build_command .append (build_target )
534543
535544 wheel_build_command .append ("--" )
@@ -538,7 +547,7 @@ async def main():
538547 logger .debug ("Artifacts output directory: %s" , output_path )
539548
540549 if args .editable :
541- logger .debug ("Building an editable build" )
550+ logger .info ("Building an editable build" )
542551 output_path = os .path .join (output_path , wheel )
543552 wheel_build_command .append ("--editable" )
544553
0 commit comments