File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed
Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change 11import os
22import ctypes
3+ from importlib .metadata import distribution , PackageNotFoundError
34from pathlib import Path
45from typing import Tuple
56
@@ -62,11 +63,12 @@ def __init__(
6263
6364 # Preload Nvidia compiler runtime if available (i.e. torch is not built from source)
6465 try :
65- import nvidia .cuda_nvrtc
66- nvrtc_dir = Path (nvidia .cuda_nvrtc .__file__ ).parent .absolute ()
67- libnvrtc_path , * _ = filter (Path .is_file , (nvrtc_dir / "lib" ).glob ("libnvrtc.so.1*" ))
68- ctypes .CDLL (libnvrtc_path , ctypes .RTLD_LOCAL )
69- except ImportError :
66+ dist = distribution ("nvidia_cuda_nvrtc_cu12" )
67+ for file in dist .files :
68+ if file .name .startswith ("libnvrtc.so.1" ):
69+ ctypes .CDLL (dist .locate_file (file ), ctypes .RTLD_LOCAL )
70+ break
71+ except PackageNotFoundError :
7072 pass
7173
7274 self .madrona = MadronaBatchRenderer (
You can’t perform that action at this time.
0 commit comments