Skip to content

Commit ddd84f7

Browse files
committed
update build.py
1 parent c8714df commit ddd84f7

File tree

1 file changed

+30
-25
lines changed

1 file changed

+30
-25
lines changed

build/build.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,6 @@
6464
}
6565

6666

67-
def add_requirements_nightly_update_argument(parser: argparse.ArgumentParser):
68-
parser.add_argument(
69-
"--nightly_update",
70-
action="store_true",
71-
help="""
72-
If true, updates requirements_lock.txt for a corresponding version of
73-
Python and will consider dev, nightly and pre-release versions of
74-
packages.
75-
""",
76-
)
77-
78-
7967
def add_global_arguments(parser: argparse.ArgumentParser):
8068
"""Adds all the global arguments that applies to all the CLI subcommands."""
8169
parser.add_argument(
@@ -136,8 +124,8 @@ def add_global_arguments(parser: argparse.ArgumentParser):
136124
)
137125

138126

139-
def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
140-
"""Adds all the global arguments that applies to the artifact subcommands."""
127+
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
128+
"""Adds all the arguments that applies to the artifact subcommands."""
141129
parser.add_argument(
142130
"--wheels",
143131
type=str,
@@ -178,26 +166,32 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser):
178166
cuda_group.add_argument(
179167
"--cuda_version",
180168
type=str,
181-
# LINT.IfChange(cuda_version)
182-
default="12.3.2",
183-
# LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
184169
help=
185170
"""
186171
Hermetic CUDA version to use. Default is to use the version specified
187-
in the .bazelrc (12.3.2).
172+
in the .bazelrc.
173+
""",
174+
)
175+
176+
cuda_group.add_argument(
177+
"--cuda_major_version",
178+
type=str,
179+
default="12",
180+
help=
181+
"""
182+
Which CUDA major version should the wheel be tagged as? Auto-detected if
183+
--cuda_version is set. When --cuda_version is not set, the default is to
184+
set the major version to 12 to match the default in .bazelrc.
188185
""",
189186
)
190187

191188
cuda_group.add_argument(
192189
"--cudnn_version",
193190
type=str,
194-
# LINT.IfChange(cudnn_version)
195-
default="9.1.1",
196-
# LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc)
197191
help=
198192
"""
199193
Hermetic cuDNN version to use. Default is to use the version specified
200-
in the .bazelrc (9.1.1).
194+
in the .bazelrc.
201195
""",
202196
)
203197

@@ -317,14 +311,22 @@ async def main():
317311
requirements_update_parser = subparsers.add_parser(
318312
"requirements_update", help="Updates the requirements_lock.txt files"
319313
)
320-
add_requirements_nightly_update_argument(requirements_update_parser)
314+
requirements_update_parser.add_argument(
315+
"--nightly_update",
316+
action="store_true",
317+
help="""
318+
If true, updates requirements_lock.txt for a corresponding version of
319+
Python and will consider dev, nightly and pre-release versions of
320+
packages.
321+
""",
322+
)
321323
add_global_arguments(requirements_update_parser)
322324

323325
# Artifact build subcommand
324326
build_artifact_parser = subparsers.add_parser(
325327
"build", help="Builds the jaxlib, plugin, and pjrt artifact"
326328
)
327-
add_artifact_subcommand_global_arguments(build_artifact_parser)
329+
add_artifact_subcommand_arguments(build_artifact_parser)
328330
add_global_arguments(build_artifact_parser)
329331

330332
arch = platform.machine()
@@ -556,7 +558,10 @@ async def main():
556558

557559
if "cuda" in wheel:
558560
wheel_build_command.append("--enable-cuda=True")
559-
cuda_major_version = args.cuda_version.split(".")[0]
561+
if args.cuda_version:
562+
cuda_major_version = args.cuda_version.split(".")[0]
563+
else:
564+
cuda_major_version = args.cuda_major_version
560565
wheel_build_command.append(f"--platform_version={cuda_major_version}")
561566

562567
if "rocm" in wheel:

0 commit comments

Comments
 (0)