|
64 | 64 | } |
65 | 65 |
|
66 | 66 |
|
67 | | -def add_requirements_nightly_update_argument(parser: argparse.ArgumentParser): |
68 | | - parser.add_argument( |
69 | | - "--nightly_update", |
70 | | - action="store_true", |
71 | | - help=""" |
72 | | - If true, updates requirements_lock.txt for a corresponding version of |
73 | | - Python and will consider dev, nightly and pre-release versions of |
74 | | - packages. |
75 | | - """, |
76 | | - ) |
77 | | - |
78 | | - |
79 | 67 | def add_global_arguments(parser: argparse.ArgumentParser): |
80 | 68 | """Adds all the global arguments that applies to all the CLI subcommands.""" |
81 | 69 | parser.add_argument( |
@@ -136,8 +124,8 @@ def add_global_arguments(parser: argparse.ArgumentParser): |
136 | 124 | ) |
137 | 125 |
|
138 | 126 |
|
139 | | -def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): |
140 | | - """Adds all the global arguments that applies to the artifact subcommands.""" |
| 127 | +def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser): |
| 128 | + """Adds all the arguments that applies to the artifact subcommands.""" |
141 | 129 | parser.add_argument( |
142 | 130 | "--wheels", |
143 | 131 | type=str, |
@@ -178,26 +166,32 @@ def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): |
178 | 166 | cuda_group.add_argument( |
179 | 167 | "--cuda_version", |
180 | 168 | type=str, |
181 | | - # LINT.IfChange(cuda_version) |
182 | | - default="12.3.2", |
183 | | - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) |
184 | 169 | help= |
185 | 170 | """ |
186 | 171 | Hermetic CUDA version to use. Default is to use the version specified |
187 | | - in the .bazelrc (12.3.2). |
| 172 | + in the .bazelrc. |
| 173 | + """, |
| 174 | + ) |
| 175 | + |
| 176 | + cuda_group.add_argument( |
| 177 | + "--cuda_major_version", |
| 178 | + type=str, |
| 179 | + default="12", |
| 180 | + help= |
| 181 | + """ |
| 182 | + Which CUDA major version should the wheel be tagged as? Auto-detected if |
| 183 | + --cuda_version is set. When --cuda_version is not set, the default is to |
| 184 | + set the major version to 12 to match the default in .bazelrc. |
188 | 185 | """, |
189 | 186 | ) |
190 | 187 |
|
191 | 188 | cuda_group.add_argument( |
192 | 189 | "--cudnn_version", |
193 | 190 | type=str, |
194 | | - # LINT.IfChange(cudnn_version) |
195 | | - default="9.1.1", |
196 | | - # LINT.ThenChange(//depot/google3/third_party/py/jax/oss/.bazelrc) |
197 | 191 | help= |
198 | 192 | """ |
199 | 193 | Hermetic cuDNN version to use. Default is to use the version specified |
200 | | - in the .bazelrc (9.1.1). |
| 194 | + in the .bazelrc. |
201 | 195 | """, |
202 | 196 | ) |
203 | 197 |
|
@@ -317,14 +311,22 @@ async def main(): |
317 | 311 | requirements_update_parser = subparsers.add_parser( |
318 | 312 | "requirements_update", help="Updates the requirements_lock.txt files" |
319 | 313 | ) |
320 | | - add_requirements_nightly_update_argument(requirements_update_parser) |
| 314 | + requirements_update_parser.add_argument( |
| 315 | + "--nightly_update", |
| 316 | + action="store_true", |
| 317 | + help=""" |
| 318 | + If true, updates requirements_lock.txt for a corresponding version of |
| 319 | + Python and will consider dev, nightly and pre-release versions of |
| 320 | + packages. |
| 321 | + """, |
| 322 | + ) |
321 | 323 | add_global_arguments(requirements_update_parser) |
322 | 324 |
|
323 | 325 | # Artifact build subcommand |
324 | 326 | build_artifact_parser = subparsers.add_parser( |
325 | 327 | "build", help="Builds the jaxlib, plugin, and pjrt artifact" |
326 | 328 | ) |
327 | | - add_artifact_subcommand_global_arguments(build_artifact_parser) |
| 329 | + add_artifact_subcommand_arguments(build_artifact_parser) |
328 | 330 | add_global_arguments(build_artifact_parser) |
329 | 331 |
|
330 | 332 | arch = platform.machine() |
@@ -556,7 +558,10 @@ async def main(): |
556 | 558 |
|
557 | 559 | if "cuda" in wheel: |
558 | 560 | wheel_build_command.append("--enable-cuda=True") |
559 | | - cuda_major_version = args.cuda_version.split(".")[0] |
| 561 | + if args.cuda_version: |
| 562 | + cuda_major_version = args.cuda_version.split(".")[0] |
| 563 | + else: |
| 564 | + cuda_major_version = args.cuda_major_version |
560 | 565 | wheel_build_command.append(f"--platform_version={cuda_major_version}") |
561 | 566 |
|
562 | 567 | if "rocm" in wheel: |
|
0 commit comments