Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 48 additions & 43 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,32 +179,32 @@ def load_config():
sys.exit(1)
with open(CONFIG_PATH, 'r') as f: return json.load(f)

def get_gpu_info():
if sys.platform == "darwin":
try:
out = subprocess.check_output(
["system_profiler", "SPDisplaysDataType"],
encoding="utf-8",
stderr=subprocess.DEVNULL
)
for line in out.split("\n"):
if "Chip" in line:
name = line.split(":", 1)[1].strip()
return name, "APPLE"
except:
pass
return "Apple Silicon (MPS)", "APPLE"

try:
name = subprocess.check_output(
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
encoding='utf-8',
stderr=subprocess.DEVNULL
).strip()
return name, "NVIDIA"
except: pass

if IS_WIN:
def get_gpu_info():
if sys.platform == "darwin":
try:
out = subprocess.check_output(
["system_profiler", "SPDisplaysDataType"],
encoding="utf-8",
stderr=subprocess.DEVNULL
)
for line in out.split("\n"):
if "Chip" in line:
name = line.split(":", 1)[1].strip()
return name, "APPLE"
except:
pass
return "Apple Silicon (MPS)", "APPLE"
try:
name = subprocess.check_output(
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
encoding='utf-8',
stderr=subprocess.DEVNULL
).strip()
return name, "NVIDIA"
except: pass
if IS_WIN:
try:
name = subprocess.check_output(
"wmic path win32_VideoController get name",
Expand All @@ -230,11 +230,11 @@ def get_gpu_info():

return "Unknown", "UNKNOWN"

def get_profile_key(gpu_name, vendor):
g = gpu_name.upper()
if vendor == "APPLE":
return "MPS"
if vendor == "NVIDIA":
def get_profile_key(gpu_name, vendor):
g = gpu_name.upper()
if vendor == "APPLE":
return "MPS"
if vendor == "NVIDIA":
if "50" in g: return "RTX_50"
if "40" in g: return "RTX_40"
if "30" in g: return "RTX_30"
Expand All @@ -247,10 +247,10 @@ def get_profile_key(gpu_name, vendor):
return "AMD_GFX110X"
return "RTX_40"

def get_os_key():
if sys.platform == "darwin":
return "macos"
return "win" if IS_WIN else "linux"
def get_os_key():
if sys.platform == "darwin":
return "macos"
return "win" if IS_WIN else "linux"

def resolve_cmd(cmd_entry):
if isinstance(cmd_entry, dict):
Expand Down Expand Up @@ -370,6 +370,11 @@ def install_logic(env_name, env_type, env_path, py_k, torch_k, triton_k, sage_k,

pip = template["install"].format(dir=env_path)

# Upgrade pip in the freshly-created environment
if env_type in ["venv", "uv"]:
py_exec = template["run"].format(dir=env_path)
run_cmd(f"{py_exec} -m pip install --upgrade pip setuptools wheel")

print(f"\n[2/3] Installing Torch: {config['components']['torch'][torch_k]['label']}...")
torch_cmd = resolve_cmd(config['components']['torch'][torch_k]['cmd'])
run_cmd(f"{pip} {torch_cmd}")
Expand Down Expand Up @@ -935,12 +940,12 @@ def create_wgp_config(profile_key, config_data):

prof_settings = config_data['gpu_profiles'].get(profile_key, {})

attn_mode = prof_settings.get("attention", "")
if not attn_mode:
if "50" in profile_key or "40" in profile_key or "30" in profile_key:
attn_mode = "sage2"
elif "20" in profile_key:
attn_mode = "sage"
attn_mode = prof_settings.get("attention", "")
if not attn_mode:
if "50" in profile_key or "40" in profile_key or "30" in profile_key:
attn_mode = "sage2"
elif "20" in profile_key:
attn_mode = "sage"

compile_mode = ""
triton_key = prof_settings.get('triton')
Expand Down Expand Up @@ -1094,5 +1099,5 @@ def repair_git_repo():
else:
print("\n[*] Code is already up to date. Skipping requirements installation.")

elif args.mode == "upgrade":
do_upgrade(cfg)
elif args.mode == "upgrade":
do_upgrade(cfg)