7878onnxruntime_validation .check_distro_info ()
7979
8080
81- def check_and_load_cuda_libs (root_directory , cuda_libs ):
81+ def check_and_load_cuda_libs (root_directory , cuda_libs_ ):
8282 # Convert the target library names to lowercase for case-insensitive comparison
83+ # Convert the target library names to lowercase for case-insensitive comparison
84+ if cuda_libs_ is None or len (cuda_libs_ ) == 0 :
85+ logging .debug ("No CUDA libraries provided for loading." )
86+ return
87+ cuda_libs_ = {lib .lower () for lib in cuda_libs_ }
8388 found_libs = {}
8489 for dirpath , _ , filenames in os .walk (root_directory ):
8590 # Convert filenames in the current directory to lowercase for comparison
8691 files_in_dir = {file .lower (): file for file in filenames } # Map lowercase to original
8792 # Find common libraries in the current directory
88- matched_libs = cuda_libs .intersection (files_in_dir .keys ())
93+ matched_libs = cuda_libs_ .intersection (files_in_dir .keys ())
8994 for lib in matched_libs :
9095 # Store the full path of the found DLL
9196 full_path = os .path .join (dirpath , files_in_dir [lib ])
@@ -98,11 +103,11 @@ def check_and_load_cuda_libs(root_directory, cuda_libs):
98103 logging .error (f"Failed to load { full_path } : { e } " )
99104
100105 # If all required libraries are found, stop the search
101- if set (found_libs .keys ()) == cuda_libs :
106+ if set (found_libs .keys ()) == cuda_libs_ :
102107 print ("All required CUDA libraries found and loaded." )
103- return True
104- logging .error (f"Failed to load all required CUDA libraries. missing libraries: { cuda_libs - found_libs .keys ()} " )
105- return False
108+ return
109+ logging .error (f"Failed to load all required CUDA libraries. missing libraries: { cuda_libs_ - found_libs .keys ()} " )
110+ return
106111
107112
108113# Load nvidia libraries from site-packages/nvidia if the package is onnxruntime-gpu
@@ -117,7 +122,6 @@ def check_and_load_cuda_libs(root_directory, cuda_libs):
117122 import logging
118123 import os
119124 import platform
120- import re
121125 import site
122126
123127 # Get the site-packages path where nvidia packages are installed
@@ -129,15 +133,15 @@ def check_and_load_cuda_libs(root_directory, cuda_libs):
129133 # Define the list of DLL patterns, nvrtc, curand and nvJitLink are not included for Windows
130134 if (11 , 0 ) <= cuda_version () < (12 , 0 ):
131135 cuda_libs = (
132- "cublasLT64_11 .dll" ,
136+ "cublaslt64_11 .dll" ,
133137 "cublas64_11.dll" ,
134138 "cufft64_10.dll" ,
135139 "cudart64_11.dll" ,
136140 "cudnn64_8.dll" ,
137141 )
138142 elif (12 , 0 ) <= cuda_version () < (13 , 0 ):
139143 cuda_libs = (
140- "cublasLT64_12 .dll" ,
144+ "cublaslt64_12 .dll" ,
141145 "cublas64_12.dll" ,
142146 "cufft64_11.dll" ,
143147 "cudart64_12.dll" ,
@@ -166,10 +170,5 @@ def check_and_load_cuda_libs(root_directory, cuda_libs):
166170 "libnvrtc.so.12" ,
167171 )
168172 else :
169- logging .error (f"Unsupported platform: { platform .system ()} " )
170-
171- if cuda_libs :
172- # Convert the target library names to lowercase for case-insensitive comparison
173- cuda_libs = {lib .lower () for lib in cuda_libs }
174- # Load the required CUDA libraries
175- check_and_load_cuda_libs (nvidia_path , cuda_libs )
173+ logging .debug (f"Unsupported platform: { platform .system ()} " )
174+ check_and_load_cuda_libs (nvidia_path , cuda_libs )
0 commit comments