Skip to content

Commit e26333f

Browse files
committed
Make cuda args to be None by default and add function to retrive cuda version from .bazelrc
1 parent d72f55e commit e26333f

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

build/build.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def add_cuda_version_argument(parser: argparse.ArgumentParser):
7373
parser.add_argument(
7474
"--cuda_version",
7575
type=str,
76-
default="12.3.2",
76+
default=None,
7777
help="Hermetic CUDA version to use",
7878
)
7979

@@ -82,7 +82,7 @@ def add_cudnn_version_argument(parser: argparse.ArgumentParser):
8282
parser.add_argument(
8383
"--cudnn_version",
8484
type=str,
85-
default="9.1.1",
85+
default=None,
8686
help="Hermetic cuDNN version to use",
8787
)
8888

@@ -577,7 +577,10 @@ async def main():
577577

578578
if "cuda" in args.command:
579579
bazel_command.append("--enable-cuda=True")
580-
cuda_major_version = args.cuda_version.split(".")[0]
580+
if args.cuda_version:
581+
cuda_major_version = args.cuda_version.split(".")[0]
582+
else:
583+
cuda_major_version = utils.get_cuda_major_version()
581584
bazel_command.append(f"--platform_version={cuda_major_version}")
582585

583586
if "rocm" in args.command:

build/tools/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ def get_clang_path_or_exit():
192192
)
193193
sys.exit(-1)
194194

195+
def get_cuda_major_version():
196+
"""Extract the CUDA major version from the .bazelrc"""
197+
with open(".bazelrc", "r") as f:
198+
for line in f:
199+
match = re.search(r'HERMETIC_CUDA_VERSION="([^"]+)"', line)
200+
if match:
201+
cuda_version=match.group(1)
202+
return cuda_version.split(".")[0]
203+
return None
204+
195205

196206
def get_bazelrc_config(os_name: str, arch: str, artifact: str, request_rbe: bool):
197207
"""Returns the bazelrc config for the given architecture and OS.
@@ -202,8 +212,8 @@ def get_bazelrc_config(os_name: str, arch: str, artifact: str, request_rbe: bool
202212

203213
bazelrc_config = f"{os_name}_{arch}"
204214

205-
# If a build is requesting RBE, the CLI will use RBE if the host system
206-
# supports it, otherwise it will use the "ci_" (non RBE) config.
215+
# If a build is requesting RBE, the CLI will use RBE if the host system supports
216+
# it, otherwise it will use the "ci_" (non RBE) config.
207217
if request_rbe:
208218
if (os_name == "linux" and arch == "x86_64") or (
209219
os_name == "windows" and arch == "amd64"

0 commit comments

Comments
 (0)