Skip to content

Commit 784c399

Browse files
committed
update build.py
1 parent cb4a693 commit 784c399

File tree

2 files changed

+98
-163
lines changed

2 files changed

+98
-163
lines changed

build/build.py

Lines changed: 98 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
# CLI for building jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin,
18-
# jax-rocm-pjrt and for updating the requirements_lock.txt files.
17+
# CLI for building JAX wheel packages from source and for updating the
18+
# requirements_lock.txt files
1919

2020
import argparse
2121
import asyncio
@@ -44,11 +44,14 @@
4444

4545
EPILOG = """
4646
From the root directory of the JAX repository, run
47-
python build/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt]
48-
49-
to build one of: jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt
50-
or
51-
python build/build.py requirements_update to update the requirements_lock.txt
47+
`python build/build.py build --wheels=<list of JAX wheels>` to build JAX
48+
artifacts.
49+
50+
Multiple wheels can be built at the same time.
51+
E.g. python build/build.py build --wheels=jaxlib,jax-cuda-plugin
52+
53+
To update the requirements_lock.txt files, run
54+
`python build/build.py requirements_update`
5255
"""
5356

5457
# Define the build target for each artifact.
@@ -136,13 +139,13 @@ def add_global_arguments(parser: argparse.ArgumentParser):
136139
def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
137140
"""Adds all the global arguments that applies to the artifact subcommands."""
138141
parser.add_argument(
139-
"--wheel_list",
142+
"--wheels",
140143
type=str,
141144
default="jaxlib",
142145
help=
143146
f"""
144-
A comma seprated list of JAX artifacts to build. E.g: --wheel_list="jaxlib",
145-
--wheel_list="jaxlib,jax-cuda-plugin", etc.
147+
A comma separated list of JAX artifacts to build. E.g: --wheels="jaxlib",
148+
--wheels="jaxlib,jax-cuda-plugin", etc.
146149
Valid options are: {','.join(ARTIFACT_BUILD_TARGET_DICT.keys())}
147150
""",
148151
)
@@ -174,22 +177,26 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
174177
cuda_group.add_argument(
175178
"--cuda_version",
176179
type=str,
177-
default=None,
180+
# LINT.IfChange(cuda_version)
181+
default="12.3.2",
182+
# LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
178183
help=
179184
"""
180185
Hermetic CUDA version to use. Default is to use the version specified
181-
in the .bazelrc.
186+
in the .bazelrc (12.3.2).
182187
""",
183188
)
184189

185190
cuda_group.add_argument(
186191
"--cudnn_version",
187192
type=str,
188-
default=None,
193+
# LINT.IfChange(cudnn_version)
194+
default="9.1.1",
195+
# LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
189196
help=
190197
"""
191198
Hermetic cuDNN version to use. Default is to use the version specified
192-
in the .bazelrc.
199+
in the .bazelrc (9.1.1).
193200
""",
194201
)
195202

@@ -215,8 +222,7 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
215222
action="store_true",
216223
help="""
217224
Should CUDA code be compiled using Clang? The default behavior is to
218-
compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set, CI
219-
builds always build CUDA with NVCC in CI builds.
225+
compile CUDA with NVCC.
220226
""",
221227
)
222228

@@ -245,33 +251,21 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
245251

246252
# Compile Options
247253
compile_group = parser.add_argument_group('Compile Options')
248-
compile_group.add_argument(
249-
"--use_ci_bazelrc_flags",
250-
action="store_true",
251-
help="""
252-
When set, the CLI will assume the build is being run in CI or CI like
253-
environment and will use the "rbe_/ci_" configs in the .bazelrc. These
254-
configs apply release features and set a custom C++ Clang toolchain.
255-
Only supported for jaxlib and CUDA builds.
256-
""",
257-
)
258254

259255
compile_group.add_argument(
260256
"--clang_path",
261257
type=str,
262258
default="",
263259
help="""
264-
Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags, CI
265-
bazelrc flags set a custom Clang toolchain.
260+
Path to the Clang binary to use.
266261
""",
267262
)
268263

269264
compile_group.add_argument(
270265
"--disable_mkl_dnn",
271266
action="store_true",
272267
help="""
273-
Disables MKL-DNN. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc
274-
flags enable MKL-DNN as default.
268+
Disables MKL-DNN.
275269
""",
276270
)
277271

@@ -285,8 +279,7 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
285279
enables AVX. Native enables -march=native, which generates code targeted
286280
to use all features of the current machine. Default means don't opt-in
287281
to any architectural features and use whatever the C compiler generates
288-
by default. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc flags
289-
enable release CPU features as default.
282+
by default.
290283
""",
291284
)
292285

@@ -306,69 +299,17 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
306299
""",
307300
)
308301

309-
def apply_compile_flags_non_ci(bazel_command: command.CommandBuilder, wheel: str, clang_path: str, disable_mkl_dnn: bool, build_cuda_with_clang: bool,\
310-
target_cpu_features: str, os_name: str, arch: str):
311-
clang_path = clang_path or utils.get_clang_path_or_exit()
312-
logging.debug("Using Clang as the compiler, clang path: %s", clang_path)
313-
# Use double quotes around clang path to avoid path issues on Windows.
314-
bazel_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
315-
bazel_command.append(f"--repo_env=CC=\"{clang_path}\"")
316-
bazel_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
317-
# Do not apply --config=clang on Mac as these settings do not apply to
318-
# Apple Clang.
319-
if os_name != "darwin":
320-
bazel_command.append("--config=clang")
321-
322-
if not disable_mkl_dnn:
323-
logging.debug("Enabling MKL DNN")
324-
bazel_command.append("--config=mkl_open_source_only")
325-
326-
if "cuda" in wheel:
327-
bazel_command.append("--config=cuda")
328-
bazel_command.append(
329-
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
330-
)
331-
if build_cuda_with_clang:
332-
logging.debug("Building CUDA with Clang")
333-
bazel_command.append("--config=build_cuda_with_clang")
334-
else:
335-
logging.debug("Building CUDA with NVCC")
336-
bazel_command.append("--config=build_cuda_with_nvcc")
337-
338-
if target_cpu_features == "release":
339-
logging.debug(
340-
"Using release cpu features: --config=avx_%s",
341-
"windows" if os_name == "windows" else "posix",
342-
)
343-
if arch in ["x86_64", "AMD64"]:
344-
bazel_command.append(
345-
"--config=avx_windows"
346-
if os_name == "windows"
347-
else "--config=avx_posix"
348-
)
349-
elif target_cpu_features == "native":
350-
if os_name == "windows":
351-
logger.warning(
352-
"--target_cpu_features=native is not supported on Windows;"
353-
" ignoring."
354-
)
355-
else:
356-
logging.debug("Using native cpu features: --config=native_arch_posix")
357-
bazel_command.append("--config=native_arch_posix")
358-
else:
359-
logging.debug("Using default cpu features")
360-
361302
async def main():
362303
parser = argparse.ArgumentParser(
363304
description=r"""
364-
CLI for building JAX wheel packages from source: jaxlib,
365-
jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt and for
366-
updating the requirements_lock.txt files
305+
CLI for building JAX wheel packages from source and for updating the
306+
requirements_lock.txt files
367307
""",
368308
epilog=EPILOG,
309+
formatter_class=argparse.RawDescriptionHelpFormatter
369310
)
370311

371-
# Create subparsers for jax, jaxlib, plugin, pjrt and requirements_update
312+
# Create subparsers for build_artifacts and requirements_update
372313
subparsers = parser.add_subparsers(dest="command", required=True)
373314

374315
# requirements_update subcommand
@@ -378,9 +319,9 @@ async def main():
378319
add_requirements_nightly_update_argument(requirements_update_parser)
379320
add_global_arguments(requirements_update_parser)
380321

381-
# Build Artifact subcommand
322+
# Artifact build subcommand
382323
build_artifact_parser = subparsers.add_parser(
383-
"build_artifacts", help="Builds the jaxlib, plugin, PJRT artifact"
324+
"build", help="Builds the jaxlib, plugin, and pjrt artifact"
384325
)
385326
add_artifact_subcommand_global_arguments(build_artifact_parser)
386327
add_global_arguments(build_artifact_parser)
@@ -463,13 +404,6 @@ async def main():
463404
logging.debug("Local XLA path: %s", args.local_xla_path)
464405
bazel_command_base.append(f"--override_repository=xla=\"{args.local_xla_path}\"")
465406

466-
if args.bazel_build_options:
467-
logging.debug(
468-
"Additional Bazel build options: %s", args.bazel_build_options
469-
)
470-
for option in args.bazel_build_options:
471-
bazel_command_base.append(option)
472-
473407
if args.target_cpu:
474408
logging.debug("Target CPU: %s", args.target_cpu)
475409
bazel_command_base.append(f"--cpu={args.target_cpu}")
@@ -478,41 +412,75 @@ async def main():
478412
logging.debug("Disabling NCCL")
479413
bazel_command_base.append("--config=nonccl")
480414

415+
git_hash = utils.get_githash()
416+
481417
# Wheel build command execution
482-
for wheel in args.wheel_list.split(","):
418+
for wheel in args.wheels.split(","):
483419
if wheel not in ARTIFACT_BUILD_TARGET_DICT.keys():
484420
logging.error("Incorrect wheel name provided: %s, valid choices are: %s", wheel, ",".join(ARTIFACT_BUILD_TARGET_DICT.keys()))
485-
continue
421+
sys.exit(1)
422+
486423
wheel_build_command = copy.deepcopy(bazel_command_base)
487424
print("\n")
488425
logger.info(
489426
"Building %s for %s %s...",
490427
wheel,
491428
os_name,
492429
arch,
493-
)
494-
# If running in CI, we use the "ci_"/"rbe_" configs in the .bazelrc.
495-
# These set a custom C++ Clang toolchain and the CUDA compiler to NVCC
496-
# When not running in CI, we detect the path to Clang binary and pass it
497-
# to Bazel to use as the C++ compiler. NVCC is used as the CUDA compiler
498-
# unless the user explicitly sets --config=build_cuda_with_clang.
499-
if args.use_ci_bazelrc_flags and "rocm" not in wheel:
500-
bazelrc_config = utils.get_ci_bazelrc_config(os_name, arch.lower(), wheel)
501-
logging.info("--use_ci_bazelrc_flags is set, using --config=%s from .bazelrc", bazelrc_config)
502-
wheel_build_command.append(f"--config={bazelrc_config}")
503-
else:
504-
apply_compile_flags_non_ci(
505-
wheel_build_command,
506-
wheel,
507-
args.clang_path,
508-
args.disable_mkl_dnn,
509-
args.build_cuda_with_clang,
510-
args.target_cpu_features,
511-
os_name,
512-
arch,
430+
)
431+
432+
clang_path = args.clang_path or utils.get_clang_path_or_exit()
433+
logging.debug("Using Clang as the compiler, clang path: %s", clang_path)
434+
435+
# Use double quotes around clang path to avoid path issues on Windows.
436+
wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
437+
wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"")
438+
wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")
439+
440+
# Do not apply --config=clang on Mac as these settings do not apply to
441+
# Apple Clang.
442+
if os_name != "darwin":
443+
wheel_build_command.append("--config=clang")
444+
445+
if not args.disable_mkl_dnn:
446+
logging.debug("Enabling MKL DNN")
447+
wheel_build_command.append("--config=mkl_open_source_only")
448+
449+
if args.target_cpu_features == "release":
450+
logging.debug(
451+
"Using release cpu features: --config=avx_%s",
452+
"windows" if os_name == "windows" else "posix",
513453
)
454+
if arch in ["x86_64", "AMD64"]:
455+
wheel_build_command.append(
456+
"--config=avx_windows"
457+
if os_name == "windows"
458+
else "--config=avx_posix"
459+
)
460+
elif wheel_build_command == "native":
461+
if os_name == "windows":
462+
logger.warning(
463+
"--target_cpu_features=native is not supported on Windows;"
464+
" ignoring."
465+
)
466+
else:
467+
logging.debug("Using native cpu features: --config=native_arch_posix")
468+
wheel_build_command.append("--config=native_arch_posix")
469+
else:
470+
logging.debug("Using default cpu features")
514471

515472
if "cuda" in wheel:
473+
wheel_build_command.append("--config=cuda")
474+
wheel_build_command.append(
475+
f"--action_env=CLANG_CUDA_COMPILER_PATH=\"{clang_path}\""
476+
)
477+
if args.build_cuda_with_clang:
478+
logging.debug("Building CUDA with Clang")
479+
wheel_build_command.append("--config=build_cuda_with_clang")
480+
else:
481+
logging.debug("Building CUDA with NVCC")
482+
wheel_build_command.append("--config=build_cuda_with_nvcc")
483+
516484
if args.cuda_version:
517485
logging.debug("Hermetic CUDA version: %s", args.cuda_version)
518486
wheel_build_command.append(
@@ -543,6 +511,15 @@ async def main():
543511
f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}"
544512
)
545513

514+
# Append additional build options at the end to override any options set in
515+
# .bazelrc or above.
516+
if args.bazel_build_options:
517+
logging.debug(
518+
"Additional Bazel build options: %s", args.bazel_build_options
519+
)
520+
for option in args.bazel_build_options:
521+
wheel_build_command.append(option)
522+
546523
if args.configure_only:
547524
with open(".jax_configure.bazelrc", "w") as f:
548525
jax_configure_options = utils.get_jax_configure_bazel_options(wheel_build_command.get_command_as_list())
@@ -572,21 +549,17 @@ async def main():
572549

573550
if "cuda" in wheel:
574551
wheel_build_command.append("--enable-cuda=True")
575-
if args.cuda_version:
576-
cuda_major_version = args.cuda_version.split(".")[0]
577-
else:
578-
cuda_major_version = utils.get_cuda_major_version()
552+
cuda_major_version = args.cuda_version.split(".")[0]
579553
wheel_build_command.append(f"--platform_version={cuda_major_version}")
580554

581555
if "rocm" in wheel:
582556
wheel_build_command.append("--enable-rocm=True")
583557
wheel_build_command.append(f"--platform_version={args.rocm_version}")
584558

585-
git_hash = utils.get_githash()
586559
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
587560

588561
await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
589562

590563

591564
if __name__ == "__main__":
592-
asyncio.run(main())
565+
asyncio.run(main())

0 commit comments

Comments
 (0)