-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathinstall_torch.py
More file actions
133 lines (119 loc) · 4.49 KB
/
install_torch.py
File metadata and controls
133 lines (119 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
"""
PyTorch installation script for DeepAries project.
Automatically detects CUDA availability and installs the appropriate PyTorch version.
"""
import json
import platform
import re
import subprocess
import sys
def _parse_driver_cuda_version(nvidia_smi_output):
    """Return the "CUDA Version: X.Y" value printed by nvidia-smi, or None.

    Args:
        nvidia_smi_output: Captured stdout of a plain ``nvidia-smi`` call.

    Returns:
        The version string (e.g. "12.4" or "11.8"), or None when the field
        is missing or not numeric (e.g. "N/A").
    """
    match = re.search(r"CUDA Version:\s*(\d+(?:\.\d+)*)", nvidia_smi_output or "")
    return match.group(1) if match else None


def check_cuda_available():
    """Decide whether CUDA-enabled PyTorch wheels should be installed.

    When torch is already importable, its own ``torch.cuda.is_available()``
    result and ``torch.version.cuda`` are used. Otherwise ``nvidia-smi`` is
    run. A zero exit status is taken to mean an NVIDIA GPU is present, and
    the "CUDA Version" field of its output picks the CUDA version. If that
    field cannot be read, the version falls back to "12.1".

    Returns:
        tuple: ``(True, cuda_version)`` for a GPU install, or
        ``(False, None)`` for a CPU install. ``cuda_version`` is normally a
        string such as "12.1". NOTE(review): on the torch-installed path it
        is ``torch.version.cuda``, which may be None on non-CUDA builds.
    """
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None:
        if torch.cuda.is_available():
            cuda_version = torch.version.cuda
            print(f"CUDA is available: {cuda_version}")
            return True, cuda_version
        print("CUDA is not available. Installing CPU version.")
        return False, None

    # torch is not installed yet: ask the NVIDIA driver tooling directly.
    try:
        result = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, timeout=30
        )
    except FileNotFoundError:
        print("nvidia-smi not found. Installing CPU version.")
        return False, None
    except (OSError, subprocess.TimeoutExpired) as exc:
        # e.g. permission denied or an unresponsive driver: treat as no usable GPU
        # instead of crashing or hanging the installer.
        print(f"Could not run nvidia-smi ({exc}). Installing CPU version.")
        return False, None

    if result.returncode != 0:
        print("No NVIDIA GPU detected. Installing CPU version.")
        return False, None

    print("NVIDIA GPU detected. Installing GPU version.")
    # Prefer the CUDA version nvidia-smi reports so that, for example, an 11.x
    # driver is routed to cu118 wheels; keep the old fixed default otherwise.
    return True, _parse_driver_cuda_version(result.stdout) or "12.1"
def get_pytorch_index_url(cuda_version=None):
    """Return the PyTorch wheel index URL matching a CUDA version string.

    Args:
        cuda_version: A version string such as "12.1" or "11.8", or None to
            select the CPU-only wheel index.

    Returns:
        The download.pytorch.org index URL. Only the major version is
        considered: 12 maps to cu121 and 11 maps to cu118. Any other value
        prints a notice and falls back to cu121.
    """
    if cuda_version is None:
        return "https://download.pytorch.org/whl/cpu"

    index_by_major = {
        "12": "https://download.pytorch.org/whl/cu121",  # most common 12.x build
        "11": "https://download.pytorch.org/whl/cu118",  # most common 11.x build
    }
    major, _, _ = cuda_version.partition(".")
    chosen = index_by_major.get(major)
    if chosen is None:
        print(f"Unknown CUDA version {cuda_version}, defaulting to CUDA 12.1")
        chosen = "https://download.pytorch.org/whl/cu121"
    return chosen
def install_pytorch(use_gpu=False, cuda_version=None):
    """Install torch, torchvision and torchaudio with pip from the PyTorch index.

    Args:
        use_gpu: When truthy, install from the CUDA index chosen by
            ``get_pytorch_index_url(cuda_version)``; otherwise use the CPU
            index.
        cuda_version: CUDA version string, consulted only when ``use_gpu``
            is set.

    Returns:
        True once pip exits successfully.

    Raises:
        subprocess.CalledProcessError: pip exited with a non-zero status.
            Because pip runs with ``check=True``, a False return does not
            occur in practice.
    """
    if not use_gpu:
        print("Installing PyTorch CPU version")
        wheel_index = get_pytorch_index_url(None)
    else:
        wheel_index = get_pytorch_index_url(cuda_version)
        print(f"Installing PyTorch GPU version from {wheel_index}")

    # Use the current interpreter's pip so the wheels land in this environment.
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.extend(("torch", "torchvision", "torchaudio"))
    pip_command.extend(("--index-url", wheel_index))

    print(f"Running: {' '.join(pip_command)}")
    completed = subprocess.run(pip_command, check=True)
    return completed.returncode == 0
def _report_installed_torch(use_gpu):
    """Print the version (and CUDA details when use_gpu) of the installed torch.

    The check runs in a fresh interpreter. This process may already hold a
    stale ``torch`` module in ``sys.modules`` if auto-detection imported a
    previously installed torch. A package installed after start-up may also
    not be found through this process's import caches.
    """
    probe = (
        "import json, torch\n"
        "info = {'version': torch.__version__, "
        "'cuda_available': torch.cuda.is_available()}\n"
        "if info['cuda_available']:\n"
        "    info['cuda_version'] = torch.version.cuda\n"
        "    info['device'] = torch.cuda.get_device_name(0)\n"
        "print(json.dumps(info))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", probe], capture_output=True, text=True
    )
    info = None
    if result.returncode == 0:
        try:
            # The JSON record is the last stdout line; anything before it is noise.
            info = json.loads(result.stdout.strip().splitlines()[-1])
        except (IndexError, ValueError):
            info = None
    if info is None:
        print("Warning: Could not import torch after installation")
        return

    print(f"\nPyTorch version: {info['version']}")
    if use_gpu:
        print(f"CUDA available: {info['cuda_available']}")
        if info["cuda_available"]:
            print(f"CUDA version: {info['cuda_version']}")
            print(f"GPU device: {info['device']}")
    else:
        print("Using CPU version")


def main():
    """Choose CPU or GPU wheels, install them with pip, then report the result.

    Command-line flags (read from ``sys.argv``):
        --cpu  force the CPU-only wheels (takes precedence over --gpu)
        --gpu  force CUDA wheels, assuming CUDA 12.1
    Without either flag, ``check_cuda_available()`` decides.

    Exits with status 1 if the pip installation fails.
    """
    print("=" * 60)
    print("PyTorch Installation Script for DeepAries")
    print("=" * 60)

    force_cpu = "--cpu" in sys.argv
    force_gpu = "--gpu" in sys.argv
    if force_cpu and force_gpu:
        print("Both --cpu and --gpu given; --cpu takes precedence.")

    if force_cpu:
        print("Forcing CPU installation (--cpu flag detected)")
        use_gpu = False
        cuda_version = None
    elif force_gpu:
        print("Forcing GPU installation (--gpu flag detected)")
        use_gpu = True
        cuda_version = "12.1"  # Default CUDA version
    else:
        use_gpu, cuda_version = check_cuda_available()

    try:
        success = install_pytorch(use_gpu, cuda_version)
    except subprocess.CalledProcessError as e:
        print(f"Error installing PyTorch: {e}")
        sys.exit(1)

    if not success:
        print("PyTorch installation failed.")
        sys.exit(1)

    print("\n" + "=" * 60)
    print("PyTorch installation completed successfully!")
    print("=" * 60)
    _report_installed_torch(use_gpu)
if __name__ == "__main__":
    # Run the installer only when executed as a script, not when imported.
    main()