Skip to content

Commit f8de92e

Browse files
committed
fix zluda installer bug
1 parent 32fe60b commit f8de92e

File tree

2 files changed

+8
-17
lines changed

2 files changed

+8
-17
lines changed

installer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,11 +629,11 @@ def install_rocm_zluda():
629629

630630
error = None
631631
from modules import zluda_installer
632-
zluda_installer.set_default_agent(device)
633632
try:
634633
if args.reinstall or zluda_installer.is_old_zluda():
635634
zluda_installer.uninstall()
636635
zluda_installer.install()
636+
zluda_installer.set_default_agent(device)
637637
except Exception as e:
638638
error = e
639639
log.warning(f'Failed to install ZLUDA: {e}')

modules/zluda_installer.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ def set_default_agent(agent: rocm.Agent):
3434
global default_agent # pylint: disable=global-statement
3535
default_agent = agent
3636

37-
is_nightly = is_nightly_zluda() or (not os.path.exists(path) and nightly)
37+
global nvcuda # pylint: disable=global-statement
38+
if nvcuda is None:
39+
nvcuda = ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll'))
40+
nvcuda.zluda_get_nightly_flag.restype = ctypes.c_int
41+
nvcuda.zluda_get_nightly_flag.argtypes = []
42+
is_nightly = nvcuda.zluda_get_nightly_flag() == 1
3843

3944
global hipBLASLt_available, hipBLASLt_enabled # pylint: disable=global-statement
4045
hipBLASLt_available = is_nightly and os.path.exists(rocm.blaslt_tensile_libpath)
@@ -44,28 +49,14 @@ def set_default_agent(agent: rocm.Agent):
4449
MIOpen_available = is_nightly and agent.gfx_version in (0x908, 0x90a, 0x940, 0x941, 0x942, 0x1030, 0x1100, 0x1101, 0x1102,)
4550

4651

47-
def load_nvcuda():
48-
global nvcuda # pylint: disable=global-statement
49-
if nvcuda is None:
50-
nvcuda = ctypes.windll.LoadLibrary(os.path.join(path, 'nvcuda.dll'))
51-
52-
53-
def is_old_zluda() -> bool: # ZLUDA<3.8.7
54-
load_nvcuda()
52+
def is_old_zluda() -> bool: # ZLUDA<3.8.8
5553
try:
5654
nvcuda.zluda_get_nightly_flag()
5755
return False
5856
except AttributeError:
5957
return True
6058

6159

62-
def is_nightly_zluda() -> bool:
63-
load_nvcuda()
64-
nvcuda.zluda_get_nightly_flag.restype = ctypes.c_int
65-
nvcuda.zluda_get_nightly_flag.argtypes = []
66-
return nvcuda.zluda_get_nightly_flag() == 1
67-
68-
6960
def install() -> None:
7061
if os.path.exists(path):
7162
return

0 commit comments

Comments
 (0)