1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616#
17- # CLI for building jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin,
18- # jax-rocm-pjrt and for updating the requirements_lock.txt files.
17+ # CLI for building JAX wheel packages from source and for updating the
18+ # requirements_lock.txt files
1919
2020import argparse
2121import asyncio
4444
4545EPILOG = """
4646From the root directory of the JAX repository, run
47- python build/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt]
48-
49- to build one of: jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt
50- or
51- python build/build.py requirements_update to update the requirements_lock.txt
47+ `python build/build.py build --wheels=<list of JAX wheels>` to build JAX
48+ artifacts.
49+
50+ Multiple wheels can be built at the same time.
51+ E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin
52+
53+ To update the requirements_lock.txt files, run
54+ `python build/build.py requirements_update`
5255"""
5356
5457# Define the build target for each artifact.
@@ -136,13 +139,13 @@ def add_global_arguments(parser: argparse.ArgumentParser):
136139def add_artifact_subcommand_global_arguments (parser : argparse .ArgumentParser ):
137140 """Adds all the global arguments that applies to the artifact subcommands."""
138141 parser .add_argument (
139- "--wheel_list " ,
142+ "--wheels " ,
140143 type = str ,
141144 default = "jaxlib" ,
142145 help =
143146 f"""
144- A comma seprated list of JAX artifacts to build. E.g: --wheel_list ="jaxlib",
145- --wheel_list ="jaxlib,jax-cuda-plugin", etc.
147+ A comma separated list of JAX artifacts to build. E.g: --wheels ="jaxlib",
148+ --wheels ="jaxlib,jax-cuda-plugin", etc.
146149 Valid options are: { ',' .join (ARTIFACT_BUILD_TARGET_DICT .keys ())}
147150 """ ,
148151 )
@@ -174,22 +177,26 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
174177 cuda_group .add_argument (
175178 "--cuda_version" ,
176179 type = str ,
177- default = None ,
180+ # LINT.IfChange(cuda_version)
181+ default = "12.3.2" ,
182+ # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
178183 help =
179184 """
180185 Hermetic CUDA version to use. Default is to use the version specified
181- in the .bazelrc.
186+ in the .bazelrc (12.3.2) .
182187 """ ,
183188 )
184189
185190 cuda_group .add_argument (
186191 "--cudnn_version" ,
187192 type = str ,
188- default = None ,
193+ # LINT.IfChange(cudnn_version)
194+ default = "9.1.1" ,
195+ # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
189196 help =
190197 """
191198 Hermetic cuDNN version to use. Default is to use the version specified
192- in the .bazelrc.
199+ in the .bazelrc (9.1.1) .
193200 """ ,
194201 )
195202
@@ -215,8 +222,7 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
215222 action = "store_true" ,
216223 help = """
217224 Should CUDA code be compiled using Clang? The default behavior is to
218- compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set, CI
219- builds always build CUDA with NVCC in CI builds.
225+ compile CUDA with NVCC.
220226 """ ,
221227 )
222228
@@ -245,33 +251,21 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
245251
246252 # Compile Options
247253 compile_group = parser .add_argument_group ('Compile Options' )
248- compile_group .add_argument (
249- "--use_ci_bazelrc_flags" ,
250- action = "store_true" ,
251- help = """
252- When set, the CLI will assume the build is being run in CI or CI like
253- environment and will use the "rbe_/ci_" configs in the .bazelrc. These
254- configs apply release features and set a custom C++ Clang toolchain.
255- Only supported for jaxlib and CUDA builds.
256- """ ,
257- )
258254
259255 compile_group .add_argument (
260256 "--clang_path" ,
261257 type = str ,
262258 default = "" ,
263259 help = """
264- Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags, CI
265- bazelrc flags set a custom Clang toolchain.
260+ Path to the Clang binary to use.
266261 """ ,
267262 )
268263
269264 compile_group .add_argument (
270265 "--disable_mkl_dnn" ,
271266 action = "store_true" ,
272267 help = """
273- Disables MKL-DNN. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc
274- flags enable MKL-DNN as default.
268+ Disables MKL-DNN.
275269 """ ,
276270 )
277271
@@ -285,8 +279,7 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
285279 enables AVX. Native enables -march=native, which generates code targeted
286280 to use all features of the current machine. Default means don't opt-in
287281 to any architectural features and use whatever the C compiler generates
288- by default. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc flags
289- enable release CPU features as default.
282+ by default.
290283 """ ,
291284 )
292285
@@ -306,69 +299,17 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
306299 """ ,
307300 )
308301
309- def apply_compile_flags_non_ci (bazel_command : command .CommandBuilder , wheel : str , clang_path : str , disable_mkl_dnn : bool , build_cuda_with_clang : bool ,\
310- target_cpu_features : str , os_name : str , arch : str ):
311- clang_path = clang_path or utils .get_clang_path_or_exit ()
312- logging .debug ("Using Clang as the compiler, clang path: %s" , clang_path )
313- # Use double quotes around clang path to avoid path issues on Windows.
314- bazel_command .append (f"--action_env=CLANG_COMPILER_PATH=\" { clang_path } \" " )
315- bazel_command .append (f"--repo_env=CC=\" { clang_path } \" " )
316- bazel_command .append (f"--repo_env=BAZEL_COMPILER=\" { clang_path } \" " )
317- # Do not apply --config=clang on Mac as these settings do not apply to
318- # Apple Clang.
319- if os_name != "darwin" :
320- bazel_command .append ("--config=clang" )
321-
322- if not disable_mkl_dnn :
323- logging .debug ("Enabling MKL DNN" )
324- bazel_command .append ("--config=mkl_open_source_only" )
325-
326- if "cuda" in wheel :
327- bazel_command .append ("--config=cuda" )
328- bazel_command .append (
329- f"--action_env=CLANG_CUDA_COMPILER_PATH=\" { clang_path } \" "
330- )
331- if build_cuda_with_clang :
332- logging .debug ("Building CUDA with Clang" )
333- bazel_command .append ("--config=build_cuda_with_clang" )
334- else :
335- logging .debug ("Building CUDA with NVCC" )
336- bazel_command .append ("--config=build_cuda_with_nvcc" )
337-
338- if target_cpu_features == "release" :
339- logging .debug (
340- "Using release cpu features: --config=avx_%s" ,
341- "windows" if os_name == "windows" else "posix" ,
342- )
343- if arch in ["x86_64" , "AMD64" ]:
344- bazel_command .append (
345- "--config=avx_windows"
346- if os_name == "windows"
347- else "--config=avx_posix"
348- )
349- elif target_cpu_features == "native" :
350- if os_name == "windows" :
351- logger .warning (
352- "--target_cpu_features=native is not supported on Windows;"
353- " ignoring."
354- )
355- else :
356- logging .debug ("Using native cpu features: --config=native_arch_posix" )
357- bazel_command .append ("--config=native_arch_posix" )
358- else :
359- logging .debug ("Using default cpu features" )
360-
361302async def main ():
362303 parser = argparse .ArgumentParser (
363304 description = r"""
364- CLI for building JAX wheel packages from source: jaxlib,
365- jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt and for
366- updating the requirements_lock.txt files
305+ CLI for building JAX wheel packages from source and for updating the
306+ requirements_lock.txt files
367307 """ ,
368308 epilog = EPILOG ,
309+ formatter_class = argparse .RawDescriptionHelpFormatter
369310 )
370311
371- # Create subparsers for jax, jaxlib, plugin, pjrt and requirements_update
312+ # Create subparsers for build_artifacts and requirements_update
372313 subparsers = parser .add_subparsers (dest = "command" , required = True )
373314
374315 # requirements_update subcommand
@@ -378,9 +319,9 @@ async def main():
378319 add_requirements_nightly_update_argument (requirements_update_parser )
379320 add_global_arguments (requirements_update_parser )
380321
381- # Build Artifact subcommand
322+ # Artifact build subcommand
382323 build_artifact_parser = subparsers .add_parser (
383- "build_artifacts " , help = "Builds the jaxlib, plugin, PJRT artifact"
324+ "build " , help = "Builds the jaxlib, plugin, and pjrt artifact"
384325 )
385326 add_artifact_subcommand_global_arguments (build_artifact_parser )
386327 add_global_arguments (build_artifact_parser )
@@ -463,13 +404,6 @@ async def main():
463404 logging .debug ("Local XLA path: %s" , args .local_xla_path )
464405 bazel_command_base .append (f"--override_repository=xla=\" { args .local_xla_path } \" " )
465406
466- if args .bazel_build_options :
467- logging .debug (
468- "Additional Bazel build options: %s" , args .bazel_build_options
469- )
470- for option in args .bazel_build_options :
471- bazel_command_base .append (option )
472-
473407 if args .target_cpu :
474408 logging .debug ("Target CPU: %s" , args .target_cpu )
475409 bazel_command_base .append (f"--cpu={ args .target_cpu } " )
@@ -478,41 +412,75 @@ async def main():
478412 logging .debug ("Disabling NCCL" )
479413 bazel_command_base .append ("--config=nonccl" )
480414
415+ git_hash = utils .get_githash ()
416+
481417 # Wheel build command execution
482- for wheel in args .wheel_list .split ("," ):
418+ for wheel in args .wheels .split ("," ):
483419 if wheel not in ARTIFACT_BUILD_TARGET_DICT .keys ():
484420 logging .error ("Incorrect wheel name provided: %s, valid choices are: %s" , wheel , "," .join (ARTIFACT_BUILD_TARGET_DICT .keys ()))
485- continue
421+ sys .exit (1 )
422+
486423 wheel_build_command = copy .deepcopy (bazel_command_base )
487424 print ("\n " )
488425 logger .info (
489426 "Building %s for %s %s..." ,
490427 wheel ,
491428 os_name ,
492429 arch ,
493- )
494- # If running in CI, we use the "ci_"/"rbe_" configs in the .bazelrc.
495- # These set a custom C++ Clang toolchain and the CUDA compiler to NVCC
496- # When not running in CI, we detect the path to Clang binary and pass it
497- # to Bazel to use as the C++ compiler. NVCC is used as the CUDA compiler
498- # unless the user explicitly sets --config=build_cuda_with_clang.
499- if args .use_ci_bazelrc_flags and "rocm" not in wheel :
500- bazelrc_config = utils .get_ci_bazelrc_config (os_name , arch .lower (), wheel )
501- logging .info ("--use_ci_bazelrc_flags is set, using --config=%s from .bazelrc" , bazelrc_config )
502- wheel_build_command .append (f"--config={ bazelrc_config } " )
503- else :
504- apply_compile_flags_non_ci (
505- wheel_build_command ,
506- wheel ,
507- args .clang_path ,
508- args .disable_mkl_dnn ,
509- args .build_cuda_with_clang ,
510- args .target_cpu_features ,
511- os_name ,
512- arch ,
430+ )
431+
432+ clang_path = args .clang_path or utils .get_clang_path_or_exit ()
433+ logging .debug ("Using Clang as the compiler, clang path: %s" , clang_path )
434+
435+ # Use double quotes around clang path to avoid path issues on Windows.
436+ wheel_build_command .append (f"--action_env=CLANG_COMPILER_PATH=\" { clang_path } \" " )
437+ wheel_build_command .append (f"--repo_env=CC=\" { clang_path } \" " )
438+ wheel_build_command .append (f"--repo_env=BAZEL_COMPILER=\" { clang_path } \" " )
439+
440+ # Do not apply --config=clang on Mac as these settings do not apply to
441+ # Apple Clang.
442+ if os_name != "darwin" :
443+ wheel_build_command .append ("--config=clang" )
444+
445+ if not args .disable_mkl_dnn :
446+ logging .debug ("Enabling MKL DNN" )
447+ wheel_build_command .append ("--config=mkl_open_source_only" )
448+
449+ if args .target_cpu_features == "release" :
450+ logging .debug (
451+ "Using release cpu features: --config=avx_%s" ,
452+ "windows" if os_name == "windows" else "posix" ,
513453 )
454+ if arch in ["x86_64" , "AMD64" ]:
455+ wheel_build_command .append (
456+ "--config=avx_windows"
457+ if os_name == "windows"
458+ else "--config=avx_posix"
459+ )
460+ elif wheel_build_command == "native" :
461+ if os_name == "windows" :
462+ logger .warning (
463+ "--target_cpu_features=native is not supported on Windows;"
464+ " ignoring."
465+ )
466+ else :
467+ logging .debug ("Using native cpu features: --config=native_arch_posix" )
468+ wheel_build_command .append ("--config=native_arch_posix" )
469+ else :
470+ logging .debug ("Using default cpu features" )
514471
515472 if "cuda" in wheel :
473+ wheel_build_command .append ("--config=cuda" )
474+ wheel_build_command .append (
475+ f"--action_env=CLANG_CUDA_COMPILER_PATH=\" { clang_path } \" "
476+ )
477+ if args .build_cuda_with_clang :
478+ logging .debug ("Building CUDA with Clang" )
479+ wheel_build_command .append ("--config=build_cuda_with_clang" )
480+ else :
481+ logging .debug ("Building CUDA with NVCC" )
482+ wheel_build_command .append ("--config=build_cuda_with_nvcc" )
483+
516484 if args .cuda_version :
517485 logging .debug ("Hermetic CUDA version: %s" , args .cuda_version )
518486 wheel_build_command .append (
@@ -543,6 +511,15 @@ async def main():
543511 f"--action_env=TF_ROCM_AMDGPU_TARGETS={ args .rocm_amdgpu_targets } "
544512 )
545513
514+ # Append additional build options at the end to override any options set in
515+ # .bazelrc or above.
516+ if args .bazel_build_options :
517+ logging .debug (
518+ "Additional Bazel build options: %s" , args .bazel_build_options
519+ )
520+ for option in args .bazel_build_options :
521+ wheel_build_command .append (option )
522+
546523 if args .configure_only :
547524 with open (".jax_configure.bazelrc" , "w" ) as f :
548525 jax_configure_options = utils .get_jax_configure_bazel_options (wheel_build_command .get_command_as_list ())
@@ -572,21 +549,17 @@ async def main():
572549
573550 if "cuda" in wheel :
574551 wheel_build_command .append ("--enable-cuda=True" )
575- if args .cuda_version :
576- cuda_major_version = args .cuda_version .split ("." )[0 ]
577- else :
578- cuda_major_version = utils .get_cuda_major_version ()
552+ cuda_major_version = args .cuda_version .split ("." )[0 ]
579553 wheel_build_command .append (f"--platform_version={ cuda_major_version } " )
580554
581555 if "rocm" in wheel :
582556 wheel_build_command .append ("--enable-rocm=True" )
583557 wheel_build_command .append (f"--platform_version={ args .rocm_version } " )
584558
585- git_hash = utils .get_githash ()
586559 wheel_build_command .append (f"--jaxlib_git_hash={ git_hash } " )
587560
588561 await executor .run (wheel_build_command .get_command_as_string (), args .dry_run )
589562
590563
591564if __name__ == "__main__" :
592- asyncio .run (main ())
565+ asyncio .run (main ())
0 commit comments