Skip to content

Commit 59ffcca

Browse files
committed
Format files
1 parent 3c5fd4a commit 59ffcca

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

tools/cuda_utils.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def prepare_cuda_env(cuda_version: str, dryrun=False):
3737
env["CUDA_HOME"] = cuda_path_str
3838
env["PATH"] = f"{cuda_path_str}/bin:{env['PATH']}"
3939
env["CMAKE_CUDA_COMPILER"] = str(cuda_path.joinpath("bin", "nvcc").resolve())
40-
env["LD_LIBRARY_PATH"] = (
41-
f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}"
42-
)
40+
env[
41+
"LD_LIBRARY_PATH"
42+
] = f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}"
4343
if dryrun:
4444
print(f"CUDA_HOME is set to {env['CUDA_HOME']}")
4545
# step 2: test call to nvcc to confirm the version is correct
@@ -88,6 +88,7 @@ def setup_cuda_softlink(cuda_version: str):
8888

8989
def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
9090
from .torch_utils import TORCH_NIGHTLY_PACKAGES
91+
9192
uninstall_torch_cmd = ["pip", "uninstall", "-y"]
9293
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
9394
if dryrun:
@@ -168,11 +169,15 @@ def install_torch_deps(cuda_version: str):
168169
install_torch_deps(cuda_version=args.cudaver)
169170
if args.install_torch_build_deps:
170171
from .torch_utils import install_torch_build_deps
172+
171173
install_torch_deps(cuda_version=args.cudaver)
172174
install_torch_build_deps()
173175
if args.install_torch_nightly:
174176
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
175177
if args.check_torch_nightly_version:
176178
from .torch_utils import check_torch_nightly_version
177-
assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
179+
180+
assert (
181+
not args.install_torch_nightly
182+
), "Error: Can't run install torch nightly and check version in the same command."
178183
check_torch_nightly_version(args.force_date)

tools/rocm_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def install_torch_deps():
3434

3535
def install_pytorch_nightly(rocm_version: str, env, dryrun=False):
3636
from .torch_utils import TORCH_NIGHTLY_PACKAGES
37+
3738
uninstall_torch_cmd = ["pip", "uninstall", "-y"]
3839
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
3940
if dryrun:

tools/torch_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
CUDA/ROCM independent pytorch installation helpers.
33
"""
44

5-
import subprocess
65
import importlib
76
import re
7+
import subprocess
88
from pathlib import Path
99

1010
from typing import Optional
@@ -18,9 +18,11 @@
1818

1919
def is_hip() -> bool:
2020
import torch
21+
2122
version = torch.__version__
2223
return "rocm" in version
2324

25+
2426
def install_torch_build_deps():
2527
# Pin cmake version to stable
2628
# See: https://github.com/pytorch/builder/pull/1269
@@ -52,6 +54,7 @@ def install_torch_build_deps():
5254
cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
5355
subprocess.check_call(cmd)
5456

57+
5558
def get_torch_nightly_version(pkg_name: str):
5659
pkg = importlib.import_module(pkg_name)
5760
version = pkg.__version__
@@ -75,4 +78,4 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
7578
force_date_str = f"User force date {force_date}" if force_date else ""
7679
print(
7780
f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
78-
)
81+
)

0 commit comments

Comments
 (0)