-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Adding optional CUDA DLLs when installing onnxruntime_gpu #22506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
65b0f6b
a335dc6
87c51fb
c95dbce
990752e
e15e3e0
5e70dd0
3bf6817
faa6e3a
2efad16
f4f9d35
c025719
c7d0951
5fea4d4
ca752b6
f27a566
120ddf9
644b52e
f557c1e
f170664
5a0e3fb
cddc500
81ef596
05e8441
513522f
77ac10a
fa785d7
cf7ba65
d24c96c
e9b913f
144f066
0560de9
053d0a3
457f4a2
014833f
d2cbf27
02c73c6
cd02bdb
b342c18
11b2604
74b8a91
5ddfc0f
287bd46
dff876c
5923d1b
cf612fc
d0ffbaf
ad0cf6b
fad62cb
f2aa262
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -76,3 +76,97 @@ | |||||||
| __version__ = version | ||||||||
|
|
||||||||
| onnxruntime_validation.check_distro_info() | ||||||||
|
|
||||||||
|
|
||||||||
def check_and_load_cuda_libs(root_directory, cuda_libs_):
    """Search ``root_directory`` recursively and preload the named shared libraries.

    File names are matched case-insensitively. A library only counts as
    satisfied once ``ctypes.CDLL`` has actually loaded it; a copy that fails to
    load is logged and the search continues, so a later directory may still
    provide a loadable copy. Within one directory, libraries are loaded in the
    order the caller listed them.

    Args:
        root_directory: Directory tree to walk looking for the libraries.
        cuda_libs_: Iterable of library file names to load. ``None`` or an
            empty collection makes the function log a message and return.

    Returns:
        None. The outcome is reported through ``logging.info`` only.
    """
    if not cuda_libs_:
        logging.info("No CUDA libraries provided for loading.")
        return

    # Lowercase the requested names for case-insensitive comparison; dict.fromkeys
    # removes duplicates while keeping the caller's order.
    requested = list(dict.fromkeys(lib.lower() for lib in cuda_libs_))
    pending = set(requested)

    for dirpath, _, filenames in os.walk(root_directory):
        # Map lowercase file names back to their on-disk spelling.
        files_in_dir = {file.lower(): file for file in filenames}
        # Only libraries that are still missing are attempted, so nothing is loaded twice.
        candidates = [lib for lib in requested if lib in pending and lib in files_in_dir]
        for lib in candidates:
            full_path = os.path.join(dirpath, files_in_dir[lib])
            try:
                ctypes.CDLL(full_path)
            except OSError as e:
                # Best effort: keep the library pending and carry on searching.
                logging.info(f"Failed to load {full_path}: {e}")
            else:
                pending.discard(lib)
                logging.info(f"Successfully loaded: {full_path}")

        # Stop walking as soon as every requested library has been loaded.
        if not pending:
            logging.info("All required CUDA libraries found and loaded.")
            return

    logging.info(
        f"Failed to load CUDA libraries from site-packages/nvidia directory: {pending}. They might be loaded later from standard search paths for shared libraries."
    )
    return
|
|
||||||||
|
|
||||||||
| # Load nvidia libraries from site-packages/nvidia if the package is onnxruntime-gpu | ||||||||
jchen351 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
| if cuda_version is not None and cuda_version != "": | ||||||||
|
Contributor
In my test, cuda_version is still an empty string. It is imported from onnxruntime.capi.onnxruntime_validation in line 73. That module only outputs cuda_version for training, as shown below: onnxruntime/onnxruntime/python/onnxruntime_validation.py, lines 100 to 102 in 29bccad
We can remove the `if has_ortmodule` line there.
Contributor
Does this mean the following code usually won't get executed? |
||||||||
| import ctypes | ||||||||
| import logging | ||||||||
| import os | ||||||||
| import platform | ||||||||
| import site | ||||||||
|
|
||||||||
# Turn the dotted version string into a tuple of ints so it can be compared
# against version ranges. This is module-level code, so an exception here would
# propagate out of the import; a component that is not a plain integer is
# therefore logged and replaced by an empty tuple, which compares below every
# (major, minor) lower bound and so selects no libraries.
try:
    cuda_version_ = tuple(map(int, cuda_version.split(".")))
except ValueError:
    logging.info(f"Unable to parse CUDA version: {cuda_version!r}")
    cuda_version_ = ()
# Get the site-packages path where nvidia packages are installed.
# NOTE(review): only the last entry returned by site.getsitepackages() is used;
# confirm whether the other entries (and the user site directory) should be searched too.
site_packages_path = site.getsitepackages()[-1]
|
Contributor
Shouldn't we check all site-packages directories? We might also check whether things like |
||||||||
# Directory below site-packages that is walked for the NVIDIA shared libraries.
nvidia_path = os.path.join(site_packages_path, "nvidia")

# Half-open [low, high) version ranges used to pick a library set.
_CUDA_11_RANGE = ((11, 0), (12, 0))
_CUDA_12_RANGE = ((12, 0), (13, 0))

# Library file names to preload, keyed by platform.system() and then by version range.
_CUDA_LIBS_BY_SYSTEM = {
    # nvrtc, curand and nvJitLink are not included for Windows.
    "Windows": {
        _CUDA_11_RANGE: (
            "cublaslt64_11.dll",
            "cublas64_11.dll",
            "cufft64_10.dll",
            "cudart64_11.dll",
            "cudnn64_8.dll",
        ),
        _CUDA_12_RANGE: (
            "cublaslt64_12.dll",
            "cublas64_12.dll",
            "cufft64_11.dll",
            "cudart64_12.dll",
            "cudnn64_9.dll",
        ),
    },
    "Linux": {
        _CUDA_11_RANGE: (
            "libcublaslt.so.11",
            "libcublas.so.11",
            "libcurand.so.10",
            "libcufft.so.10",
            "libcudart.so.11",
            "libcudnn.so.8",
            # This is not a mistake, it links to more specific version like libnvrtc.so.11.8.89 etc.
            "libnvrtc.so.11.2",
        ),
        _CUDA_12_RANGE: (
            "libcublaslt.so.12",
            "libcublas.so.12",
            "libcurand.so.10",
            "libcufft.so.11",
            "libcudart.so.12",
            "libcudnn.so.9",
            "libnvrtc.so.12",
        ),
    },
}

_system_name = platform.system()
# Stays empty when the platform is supported but the version is outside every range.
cuda_libs = ()
if _system_name in _CUDA_LIBS_BY_SYSTEM:
    for (_low, _high), _libs in _CUDA_LIBS_BY_SYSTEM[_system_name].items():
        if _low <= cuda_version_ < _high:
            cuda_libs = _libs
            break
else:
    logging.info(f"Unsupported platform: {_system_name}")

check_and_load_cuda_libs(nvidia_path, cuda_libs)
|
Contributor
Please move this code into a function like Example usage: |
||||||||
Uh oh!
There was an error while loading. Please reload this page.