1515
1616namespace BetterGenshinImpact . Core . Recognition . ONNX ;
1717
18- public class BgiOnnxFactory : Singleton < BgiOnnxFactory >
18+ public class BgiOnnxFactory
1919{
20- private static readonly ILogger < BgiOnnxFactory > Logger = App . GetLogger < BgiOnnxFactory > ( ) ;
20+ private readonly ILogger logger ;
2121
2222 /// <summary>
2323 /// 缓存模型路径。如果一开始使用缓存就一直使用缓存文件,如果没有使用缓存就一直使用原始模型路径。
@@ -26,9 +26,15 @@ public class BgiOnnxFactory : Singleton<BgiOnnxFactory>
2626 /// </summary>
2727 private readonly ConcurrentDictionary < BgiOnnxModel , string ? > _cachedModelPaths = new ( ) ;
2828
29- public BgiOnnxFactory ( )
29+ /// <summary>
30+ /// 请勿直接实例化此类
31+ /// </summary>
32+ /// <param name="hardwareAccelerationConfig"></param>
33+ /// <param name="logger"></param>
34+ public BgiOnnxFactory ( HardwareAccelerationConfig hardwareAccelerationConfig , ILogger < BgiOnnxFactory > logger )
3035 {
31- var config = TaskContext . Instance ( ) . Config . HardwareAccelerationConfig ;
36+ var config = hardwareAccelerationConfig ;
37+ this . logger = logger ;
3238 if ( config . AutoAppendCudaPath ) AppendCudaPath ( ) ;
3339
3440 if ( string . IsNullOrWhiteSpace ( config . AdditionalPath ) )
@@ -41,9 +47,9 @@ public BgiOnnxFactory()
4147 TrtUseEmbedMode = config . EmbedTensorRtCache ;
4248 EnableCache = config . EnableTensorRtCache ;
4349 CpuOcr = config . CpuOcr ;
44- Logger . LogDebug (
50+ this . logger . LogDebug (
4551 "[ONNX]启用的provider:{Device},初始化参数: InferenceDevice={InferenceDevice}, OptimizedModel={OptimizedModel}, CudaDeviceId={CudaDeviceId}, DmlDeviceId={DmlDeviceId}, EmbedTensorRtCache={EmbedTensorRtCache}, EnableTensorRtCache={EnableTensorRtCache}, CpuOcr={CpuOcr}" ,
46- string . Join ( "," , ProviderTypes . Select ( Enum . GetName ) ) ,
52+ string . Join ( "," , ProviderTypes . Select < ProviderType , string > ( Enum . GetName ) ) ,
4753 config . InferenceDevice ,
4854 OptimizedModel ,
4955 CudaDeviceId ,
@@ -70,7 +76,7 @@ public BgiOnnxFactory()
7076 /// <param name="dmlDeviceId">dml设备id</param>
7177 /// <returns></returns>
7278 /// <exception cref="InvalidEnumArgumentException"></exception>
73- private static ProviderType [ ] GetProviderType ( InferenceDeviceType inferenceDeviceType , int cudaDeviceId ,
79+ private ProviderType [ ] GetProviderType ( InferenceDeviceType inferenceDeviceType , int cudaDeviceId ,
7480 int dmlDeviceId )
7581 {
7682 switch ( inferenceDeviceType )
@@ -94,7 +100,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
94100 }
95101 catch ( Exception e )
96102 {
97- Logger . LogDebug ( "[init]无法加载TensorRt。可能不支持,跳过。({Err})" , e . Message ) ;
103+ logger . LogDebug ( "[init]无法加载TensorRt。可能不支持,跳过。({Err})" , e . Message ) ;
98104 }
99105 finally
100106 {
@@ -112,7 +118,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
112118 }
113119 catch ( Exception e )
114120 {
115- Logger . LogDebug ( "[init]无法加载DML。可能不支持,跳过。({Err})" , e . Message ) ;
121+ logger . LogDebug ( "[init]无法加载DML。可能不支持,跳过。({Err})" , e . Message ) ;
116122 }
117123 finally
118124 {
@@ -129,14 +135,14 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
129135 }
130136 catch ( Exception e )
131137 {
132- Logger . LogDebug ( "[init]无法加载Cuda。可能不支持,跳过。({Err})" , e . Message ) ;
138+ logger . LogDebug ( "[init]无法加载Cuda。可能不支持,跳过。({Err})" , e . Message ) ;
133139 }
134140 finally
135141 {
136142 testSession ? . Dispose ( ) ;
137143 }
138144
139- if ( ! hasGpu ) Logger . LogWarning ( "[init]GPU自动选择失败,回退到CPU处理" ) ;
145+ if ( ! hasGpu ) logger . LogWarning ( "[init]GPU自动选择失败,回退到CPU处理" ) ;
140146
141147 //无论如何都要加入cpu,一些计算在纯gpu上不被支持或性能很烂
142148 list . Add ( ProviderType . Cpu ) ;
@@ -149,7 +155,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
149155 /// <summary>
150156 /// 自动嗅探并修改path以加载cuda
151157 /// </summary>
152- private static void AppendCudaPath ( )
158+ private void AppendCudaPath ( )
153159 {
154160 var cudaVersion =
155161 Registry . GetValue ( @"HKEY_LOCAL_MACHINE\SOFTWARE\NVIDIA Corporation\GPU Computing Toolkit\CUDA" ,
@@ -203,7 +209,7 @@ private static void AppendCudaPath()
203209 /// 将附加的path应用进来
204210 /// </summary>
205211 /// <param name="extraPath">附加的path字符串</param>
206- private static void AppendPath ( string [ ] extraPath )
212+ private void AppendPath ( string [ ] extraPath )
207213 {
208214 if ( extraPath . Length <= 0 ) return ;
209215
@@ -212,12 +218,12 @@ private static void AppendPath(string[] extraPath)
212218 pathVariables . AddRange ( extraPath ) ;
213219 if ( pathVariables . Count <= 0 )
214220 {
215- Logger . LogWarning ( "[GpuAuto]SetCudaPath:No valid paths found." ) ;
221+ logger . LogWarning ( "[GpuAuto]SetCudaPath:No valid paths found." ) ;
216222 return ;
217223 }
218224
219225 var updatedPath = string . Join ( Path . PathSeparator , pathVariables . Distinct ( ) ) ;
220- Logger . LogDebug ( "[GpuAuto]修改进程PATH为:{UpdatedPath}" , updatedPath ) ;
226+ logger . LogDebug ( "[GpuAuto]修改进程PATH为:{UpdatedPath}" , updatedPath ) ;
221227 Environment . SetEnvironmentVariable ( "PATH" , updatedPath , EnvironmentVariableTarget . Process ) ;
222228 }
223229
@@ -228,7 +234,7 @@ private static void AppendPath(string[] extraPath)
228234 /// <returns>BgiYoloPredictor</returns>
229235 public BgiYoloPredictor CreateYoloPredictor ( BgiOnnxModel model )
230236 {
231- Logger . LogDebug ( "[Yolo]创建yolo预测器,模型: {ModelName}" , model . Name ) ;
237+ logger . LogDebug ( "[Yolo]创建yolo预测器,模型: {ModelName}" , model . Name ) ;
232238 if ( ! EnableCache ) return new BgiYoloPredictor ( model , model . ModalPath , CreateSessionOptions ( model , false ) ) ;
233239
234240 var cached = GetCached ( model ) ;
@@ -245,7 +251,7 @@ public BgiYoloPredictor CreateYoloPredictor(BgiOnnxModel model)
245251 /// <returns>InferenceSession</returns>
246252 public InferenceSession CreateInferenceSession ( BgiOnnxModel model , bool ocr = false )
247253 {
248- Logger . LogDebug ( "[ONNX]创建推理会话,模型: {ModelName}" , model . Name ) ;
254+ logger . LogDebug ( "[ONNX]创建推理会话,模型: {ModelName}" , model . Name ) ;
249255 ProviderType [ ] ? providerTypes = null ;
250256 if ( CpuOcr && ocr ) providerTypes = [ ProviderType . Cpu ] ;
251257
@@ -275,7 +281,7 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
275281 // 判断文件是否存在
276282 if ( File . Exists ( result ) ) return result ;
277283
278- Logger . LogWarning ( "[ONNX]模型 {Model} 的缓存文件可能已被删除,使用原始模型文件。" , model . Name ) ;
284+ logger . LogWarning ( "[ONNX]模型 {Model} 的缓存文件可能已被删除,使用原始模型文件。" , model . Name ) ;
279285 return null ;
280286 }
281287
@@ -289,19 +295,19 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
289295 var ctxA = Path . Combine ( model . CachePath , "trt" , "_ctx.onnx" ) ;
290296 if ( File . Exists ( ctxA ) )
291297 {
292- Logger . LogDebug ( "[ONNX]模型 {Model} 命中TRT匿名缓存文件: {Path}" , model . Name , ctxA ) ;
298+ logger . LogDebug ( "[ONNX]模型 {Model} 命中TRT匿名缓存文件: {Path}" , model . Name , ctxA ) ;
293299 return ctxA ;
294300 }
295301
296302 var ctxB = Path . Combine ( model . CachePath , "trt" ,
297303 Path . GetFileNameWithoutExtension ( model . ModalPath ) + "_ctx.onnx" ) ;
298304 if ( File . Exists ( ctxB ) )
299305 {
300- Logger . LogDebug ( "[ONNX]模型 {Model} 命中TRT命名缓存文件: {Path}" , model . Name , ctxB ) ;
306+ logger . LogDebug ( "[ONNX]模型 {Model} 命中TRT命名缓存文件: {Path}" , model . Name , ctxB ) ;
301307 return ctxB ;
302308 }
303309
304- Logger . LogDebug ( "[ONNX]没有找到模型 {Model} 的模型缓存文件。" , model . Name ) ;
310+ logger . LogDebug ( "[ONNX]没有找到模型 {Model} 的模型缓存文件。" , model . Name ) ;
305311 return null ;
306312 }
307313
@@ -315,7 +321,7 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
315321 /// <param name="forcedProvider">强制使用的Provider,为空或null则不强制</param>
316322 /// <returns></returns>
317323 /// <exception cref="InvalidEnumArgumentException"></exception>
318- private SessionOptions CreateSessionOptions ( BgiOnnxModel path , bool genCache , ProviderType [ ] ? forcedProvider = null )
324+ protected SessionOptions CreateSessionOptions ( BgiOnnxModel path , bool genCache , ProviderType [ ] ? forcedProvider = null )
319325 {
320326 var sessionOptions = new SessionOptions ( ) ;
321327 foreach ( var type in
@@ -355,7 +361,7 @@ private SessionOptions CreateSessionOptions(BgiOnnxModel path, bool genCache, Pr
355361 }
356362 catch ( Exception e )
357363 {
358- Logger . LogError ( "无法加载指定的 ONNX provider {Provider},跳过。请检查推理设备配置是否正确。({Err})" , Enum . GetName ( type ) ,
364+ logger . LogError ( "无法加载指定的 ONNX provider {Provider},跳过。请检查推理设备配置是否正确。({Err})" , Enum . GetName ( type ) ,
359365 e . Message ) ;
360366 }
361367
0 commit comments