Skip to content

Commit 0c02808

Browse files
使用TorchSharp重写RodNet,以利后续优化 (#1613)
* 使用TorchSharp重写RodNet,以利后续优化
* 增加一个外部torch加载配置来代替直接的依赖,如配置不生效则使用原先手搓的算法
* BgiOnnxFactory取消单例,改为在App服务类中注册为单例,由此修复了一堆单元测试
* BgiOnnxFactory中几个静态方法改为成员方法以和App解耦;因不再有多个mat源供消耗,FishBite中文字块算法不再改动传入的mat,使得后续串联的算法不受其影响
* 将BehavioursTests中临时的配置读取方式改为读取主项目编译环境中的json文件;新建单元测试的README
* 将RodNet算法更新到 myHuTao-qwq/HutaoFisher@010006a 的版本;RodNet中关于torch库推理和直接数学计算的校验移至单元测试
* 更新RodNet算法至最新:myHuTao-qwq/HutaoFisher@add5672
* 注释调试用的代码
1 parent eae02d6 commit 0c02808

33 files changed

+523
-229
lines changed

BetterGenshinImpact/App.xaml.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Threading.Tasks;
55
using System.Windows;
66
using System.Windows.Threading;
7+
using BetterGenshinImpact.Core.Recognition.ONNX;
78
using BetterGenshinImpact.GameTask;
89
using BetterGenshinImpact.Helpers;
910
using BetterGenshinImpact.Helpers.Extensions;
@@ -125,13 +126,17 @@ public partial class App : Application
125126
services.AddSingleton<NotifierManager>();
126127
services.AddSingleton<IScriptService, ScriptService>();
127128
services.AddSingleton<HutaoNamedPipe>();
129+
services.AddSingleton(sp=> sp.GetRequiredService<HomePageViewModel>().Config.HardwareAccelerationConfig);
130+
services.AddSingleton<BgiOnnxFactory>();
128131

129132
// Configuration
130133
//services.Configure<AppConfig>(context.Configuration.GetSection(nameof(AppConfig)));
131134
}
132135
)
133136
.Build();
134137

138+
public static IServiceProvider ServiceProvider => _host.Services;
139+
135140
public static ILogger<T> GetLogger<T>()
136141
{
137142
return _host.Services.GetService<ILogger<T>>()!;

BetterGenshinImpact/BetterGenshinImpact.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
<PackageReference Include="Microsoft.Extensions.Logging" Version="9.0.4" />
5555
<PackageReference Include="Microsoft.ML.OnnxRuntime.DirectML" Version="1.21.0" />
5656
<!--排除掉cpu的runtime dll-->
57-
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.21.0" IncludeAssets="none"/>
57+
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.21.0" IncludeAssets="none" />
5858
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.21.0" />
5959
<PackageReference Include="Microsoft.Toolkit.Uwp.Notifications" Version="7.1.3" />
6060
<PackageReference Include="Microsoft.Web.WebView2" Version="1.0.2592.51" />
@@ -74,6 +74,7 @@
7474
<PackageReference Include="Serilog.Sinks.RichTextBoxEx.Wpf" Version="1.1.0.1" />
7575
<PackageReference Include="System.Drawing.Common" Version="9.0.5" />
7676
<PackageReference Include="System.IO.Hashing" Version="9.0.4" />
77+
<PackageReference Include="TorchSharp" Version="0.105.0" />
7778
<PackageReference Include="Vanara.PInvoke.NtDll" Version="4.1.3" />
7879
<PackageReference Include="Vanara.PInvoke.SHCore" Version="4.1.3" />
7980
<PackageReference Include="Vanara.PInvoke.User32" Version="4.1.3" />

BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
using System.Collections.Concurrent;
33
using System.Collections.Generic;
44
using System.Threading.Tasks;
5+
using BetterGenshinImpact.Core.Recognition.ONNX;
56
using BetterGenshinImpact.GameTask;
7+
using Microsoft.Extensions.DependencyInjection;
68
using Microsoft.Extensions.Logging;
79

810
namespace BetterGenshinImpact.Core.Recognition.OCR;
@@ -34,7 +36,7 @@ private static KeyValuePair<string, IOcrService> CreateAndSet(OcrEngineTypes typ
3436
var result = type switch
3537
{
3638
OcrEngineTypes.Paddle => new KeyValuePair<string, IOcrService>(cultureInfoName,
37-
new PaddleOcrService(cultureInfoName)),
39+
new PaddleOcrService(cultureInfoName, App.ServiceProvider.GetRequiredService<BgiOnnxFactory>())),
3840
_ => throw new ArgumentOutOfRangeException(nameof(type), type, null)
3941
};
4042
Logger.LogDebug("为 {CultureInfoName} 创建了类型为 {Type} 的 OCR服务", result.Key, result.Value);

BetterGenshinImpact/Core/Recognition/OCR/paddle/Det.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ public class Det
1313
private readonly OcrVersionConfig _config;
1414
private readonly InferenceSession _session;
1515

16-
public Det(BgiOnnxModel model, OcrVersionConfig config)
16+
public Det(BgiOnnxModel model, OcrVersionConfig config, BgiOnnxFactory bgiOnnxFactory)
1717
{
1818
_config = config;
19-
_session = BgiOnnxFactory.Instance.CreateInferenceSession(model, true);
19+
_session = bgiOnnxFactory.CreateInferenceSession(model, true);
2020
}
2121

2222
/// <summary>Gets or sets the maximum size for resizing the input image.</summary>

BetterGenshinImpact/Core/Recognition/OCR/paddle/PaddleOcrService.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,26 @@ public class PaddleOcrService : IOcrService
2222

2323
private readonly Rec _localRecModel;
2424

25-
public PaddleOcrService(string cultureInfoName)
25+
public PaddleOcrService(string cultureInfoName, BgiOnnxFactory bgiOnnxFactory)
2626
{
2727
var path = Global.Absolute(@"Assets\Model\PaddleOcr");
2828

2929
switch (cultureInfoName)
3030
{
3131
case "zh-Hant":
32-
_localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4);
32+
_localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4, bgiOnnxFactory);
3333
_localRecModel = new Rec(BgiOnnxModel.PaddleOcrChtRec, Path.Combine(path, "chinese_cht_dict.txt"),
34-
OcrVersionConfig.PpOcrV3);
34+
OcrVersionConfig.PpOcrV3, bgiOnnxFactory);
3535
break;
3636
case "fr":
37-
_localDetModel = new Det(BgiOnnxModel.PaddleOcrEnDet, OcrVersionConfig.PpOcrV3);
37+
_localDetModel = new Det(BgiOnnxModel.PaddleOcrEnDet, OcrVersionConfig.PpOcrV3, bgiOnnxFactory);
3838
_localRecModel = new Rec(BgiOnnxModel.PaddleOcrLatinRec, Path.Combine(path, "latin_dict.txt"),
39-
OcrVersionConfig.PpOcrV3);
39+
OcrVersionConfig.PpOcrV3, bgiOnnxFactory);
4040
break;
4141
default:
42-
_localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4);
42+
_localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4, bgiOnnxFactory);
4343
_localRecModel = new Rec(BgiOnnxModel.PaddleOcrChRec, Path.Combine(path, "ppocr_keys_v1.txt"),
44-
OcrVersionConfig.PpOcrV4);
44+
OcrVersionConfig.PpOcrV4, bgiOnnxFactory);
4545

4646
break;
4747
}

BetterGenshinImpact/Core/Recognition/OCR/paddle/Rec.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ public class Rec
2020
private readonly IReadOnlyList<string> _labels;
2121
private readonly InferenceSession _session;
2222

23-
public Rec(BgiOnnxModel model, string labelFilePath, OcrVersionConfig config)
23+
public Rec(BgiOnnxModel model, string labelFilePath, OcrVersionConfig config, BgiOnnxFactory bgiOnnxFactory)
2424
{
2525
_config = config;
26-
_session = BgiOnnxFactory.Instance.CreateInferenceSession(model, true);
26+
_session = bgiOnnxFactory.CreateInferenceSession(model, true);
2727

2828

2929
_labels = File.ReadAllLines(labelFilePath);

BetterGenshinImpact/Core/Recognition/ONNX/BgiOnnxFactory.cs

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
namespace BetterGenshinImpact.Core.Recognition.ONNX;
1717

18-
public class BgiOnnxFactory : Singleton<BgiOnnxFactory>
18+
public class BgiOnnxFactory
1919
{
20-
private static readonly ILogger<BgiOnnxFactory> Logger = App.GetLogger<BgiOnnxFactory>();
20+
private readonly ILogger logger;
2121

2222
/// <summary>
2323
/// 缓存模型路径。如果一开始使用缓存就一直使用缓存文件,如果没有使用缓存就一直使用原始模型路径。
@@ -26,9 +26,15 @@ public class BgiOnnxFactory : Singleton<BgiOnnxFactory>
2626
/// </summary>
2727
private readonly ConcurrentDictionary<BgiOnnxModel, string?> _cachedModelPaths = new();
2828

29-
public BgiOnnxFactory()
29+
/// <summary>
30+
/// 请勿直接实例化此类
31+
/// </summary>
32+
/// <param name="hardwareAccelerationConfig"></param>
33+
/// <param name="logger"></param>
34+
public BgiOnnxFactory(HardwareAccelerationConfig hardwareAccelerationConfig, ILogger<BgiOnnxFactory> logger)
3035
{
31-
var config = TaskContext.Instance().Config.HardwareAccelerationConfig;
36+
var config = hardwareAccelerationConfig;
37+
this.logger = logger;
3238
if (config.AutoAppendCudaPath) AppendCudaPath();
3339

3440
if (string.IsNullOrWhiteSpace(config.AdditionalPath))
@@ -41,9 +47,9 @@ public BgiOnnxFactory()
4147
TrtUseEmbedMode = config.EmbedTensorRtCache;
4248
EnableCache = config.EnableTensorRtCache;
4349
CpuOcr = config.CpuOcr;
44-
Logger.LogDebug(
50+
this.logger.LogDebug(
4551
"[ONNX]启用的provider:{Device},初始化参数: InferenceDevice={InferenceDevice}, OptimizedModel={OptimizedModel}, CudaDeviceId={CudaDeviceId}, DmlDeviceId={DmlDeviceId}, EmbedTensorRtCache={EmbedTensorRtCache}, EnableTensorRtCache={EnableTensorRtCache}, CpuOcr={CpuOcr}",
46-
string.Join(",", ProviderTypes.Select(Enum.GetName)),
52+
string.Join(",", ProviderTypes.Select<ProviderType, string>(Enum.GetName)),
4753
config.InferenceDevice,
4854
OptimizedModel,
4955
CudaDeviceId,
@@ -70,7 +76,7 @@ public BgiOnnxFactory()
7076
/// <param name="dmlDeviceId">dml设备id</param>
7177
/// <returns></returns>
7278
/// <exception cref="InvalidEnumArgumentException"></exception>
73-
private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDeviceType, int cudaDeviceId,
79+
private ProviderType[] GetProviderType(InferenceDeviceType inferenceDeviceType, int cudaDeviceId,
7480
int dmlDeviceId)
7581
{
7682
switch (inferenceDeviceType)
@@ -94,7 +100,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
94100
}
95101
catch (Exception e)
96102
{
97-
Logger.LogDebug("[init]无法加载TensorRt。可能不支持,跳过。({Err})", e.Message);
103+
logger.LogDebug("[init]无法加载TensorRt。可能不支持,跳过。({Err})", e.Message);
98104
}
99105
finally
100106
{
@@ -112,7 +118,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
112118
}
113119
catch (Exception e)
114120
{
115-
Logger.LogDebug("[init]无法加载DML。可能不支持,跳过。({Err})", e.Message);
121+
logger.LogDebug("[init]无法加载DML。可能不支持,跳过。({Err})", e.Message);
116122
}
117123
finally
118124
{
@@ -129,14 +135,14 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
129135
}
130136
catch (Exception e)
131137
{
132-
Logger.LogDebug("[init]无法加载Cuda。可能不支持,跳过。({Err})", e.Message);
138+
logger.LogDebug("[init]无法加载Cuda。可能不支持,跳过。({Err})", e.Message);
133139
}
134140
finally
135141
{
136142
testSession?.Dispose();
137143
}
138144

139-
if (!hasGpu) Logger.LogWarning("[init]GPU自动选择失败,回退到CPU处理");
145+
if (!hasGpu) logger.LogWarning("[init]GPU自动选择失败,回退到CPU处理");
140146

141147
//无论如何都要加入cpu,一些计算在纯gpu上不被支持或性能很烂
142148
list.Add(ProviderType.Cpu);
@@ -149,7 +155,7 @@ private static ProviderType[] GetProviderType(InferenceDeviceType inferenceDevic
149155
/// <summary>
150156
/// 自动嗅探并修改path以加载cuda
151157
/// </summary>
152-
private static void AppendCudaPath()
158+
private void AppendCudaPath()
153159
{
154160
var cudaVersion =
155161
Registry.GetValue(@"HKEY_LOCAL_MACHINE\SOFTWARE\NVIDIA Corporation\GPU Computing Toolkit\CUDA",
@@ -203,7 +209,7 @@ private static void AppendCudaPath()
203209
/// 将附加的path应用进来
204210
/// </summary>
205211
/// <param name="extraPath">附加的path字符串</param>
206-
private static void AppendPath(string[] extraPath)
212+
private void AppendPath(string[] extraPath)
207213
{
208214
if (extraPath.Length <= 0) return;
209215

@@ -212,12 +218,12 @@ private static void AppendPath(string[] extraPath)
212218
pathVariables.AddRange(extraPath);
213219
if (pathVariables.Count <= 0)
214220
{
215-
Logger.LogWarning("[GpuAuto]SetCudaPath:No valid paths found.");
221+
logger.LogWarning("[GpuAuto]SetCudaPath:No valid paths found.");
216222
return;
217223
}
218224

219225
var updatedPath = string.Join(Path.PathSeparator, pathVariables.Distinct());
220-
Logger.LogDebug("[GpuAuto]修改进程PATH为:{UpdatedPath}", updatedPath);
226+
logger.LogDebug("[GpuAuto]修改进程PATH为:{UpdatedPath}", updatedPath);
221227
Environment.SetEnvironmentVariable("PATH", updatedPath, EnvironmentVariableTarget.Process);
222228
}
223229

@@ -228,7 +234,7 @@ private static void AppendPath(string[] extraPath)
228234
/// <returns>BgiYoloPredictor</returns>
229235
public BgiYoloPredictor CreateYoloPredictor(BgiOnnxModel model)
230236
{
231-
Logger.LogDebug("[Yolo]创建yolo预测器,模型: {ModelName}", model.Name);
237+
logger.LogDebug("[Yolo]创建yolo预测器,模型: {ModelName}", model.Name);
232238
if (!EnableCache) return new BgiYoloPredictor(model, model.ModalPath, CreateSessionOptions(model, false));
233239

234240
var cached = GetCached(model);
@@ -245,7 +251,7 @@ public BgiYoloPredictor CreateYoloPredictor(BgiOnnxModel model)
245251
/// <returns>InferenceSession</returns>
246252
public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = false)
247253
{
248-
Logger.LogDebug("[ONNX]创建推理会话,模型: {ModelName}", model.Name);
254+
logger.LogDebug("[ONNX]创建推理会话,模型: {ModelName}", model.Name);
249255
ProviderType[]? providerTypes = null;
250256
if (CpuOcr && ocr) providerTypes = [ProviderType.Cpu];
251257

@@ -275,7 +281,7 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
275281
// 判断文件是否存在
276282
if (File.Exists(result)) return result;
277283

278-
Logger.LogWarning("[ONNX]模型 {Model} 的缓存文件可能已被删除,使用原始模型文件。", model.Name);
284+
logger.LogWarning("[ONNX]模型 {Model} 的缓存文件可能已被删除,使用原始模型文件。", model.Name);
279285
return null;
280286
}
281287

@@ -289,19 +295,19 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
289295
var ctxA = Path.Combine(model.CachePath, "trt", "_ctx.onnx");
290296
if (File.Exists(ctxA))
291297
{
292-
Logger.LogDebug("[ONNX]模型 {Model} 命中TRT匿名缓存文件: {Path}", model.Name, ctxA);
298+
logger.LogDebug("[ONNX]模型 {Model} 命中TRT匿名缓存文件: {Path}", model.Name, ctxA);
293299
return ctxA;
294300
}
295301

296302
var ctxB = Path.Combine(model.CachePath, "trt",
297303
Path.GetFileNameWithoutExtension(model.ModalPath) + "_ctx.onnx");
298304
if (File.Exists(ctxB))
299305
{
300-
Logger.LogDebug("[ONNX]模型 {Model} 命中TRT命名缓存文件: {Path}", model.Name, ctxB);
306+
logger.LogDebug("[ONNX]模型 {Model} 命中TRT命名缓存文件: {Path}", model.Name, ctxB);
301307
return ctxB;
302308
}
303309

304-
Logger.LogDebug("[ONNX]没有找到模型 {Model} 的模型缓存文件。", model.Name);
310+
logger.LogDebug("[ONNX]没有找到模型 {Model} 的模型缓存文件。", model.Name);
305311
return null;
306312
}
307313

@@ -315,7 +321,7 @@ public InferenceSession CreateInferenceSession(BgiOnnxModel model, bool ocr = fa
315321
/// <param name="forcedProvider">强制使用的Provider,为空或null则不强制</param>
316322
/// <returns></returns>
317323
/// <exception cref="InvalidEnumArgumentException"></exception>
318-
private SessionOptions CreateSessionOptions(BgiOnnxModel path, bool genCache, ProviderType[]? forcedProvider = null)
324+
protected SessionOptions CreateSessionOptions(BgiOnnxModel path, bool genCache, ProviderType[]? forcedProvider = null)
319325
{
320326
var sessionOptions = new SessionOptions();
321327
foreach (var type in
@@ -355,7 +361,7 @@ private SessionOptions CreateSessionOptions(BgiOnnxModel path, bool genCache, Pr
355361
}
356362
catch (Exception e)
357363
{
358-
Logger.LogError("无法加载指定的 ONNX provider {Provider},跳过。请检查推理设备配置是否正确。({Err})", Enum.GetName(type),
364+
logger.LogError("无法加载指定的 ONNX provider {Provider},跳过。请检查推理设备配置是否正确。({Err})", Enum.GetName(type),
359365
e.Message);
360366
}
361367

BetterGenshinImpact/Core/Recognition/ONNX/SVTR/PickTextInference.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using System.Text;
1111
using System.Text.Json;
1212
using BetterGenshinImpact.Core.Recognition.OCR.engine;
13+
using Microsoft.Extensions.DependencyInjection;
1314

1415
namespace BetterGenshinImpact.Core.Recognition.ONNX.SVTR;
1516

@@ -24,7 +25,7 @@ public class PickTextInference : ITextInference
2425

2526
public PickTextInference()
2627
{
27-
_session = BgiOnnxFactory.Instance.CreateInferenceSession(BgiOnnxModel.YapModelTraining,true);
28+
_session = App.ServiceProvider.GetRequiredService<BgiOnnxFactory>().CreateInferenceSession(BgiOnnxModel.YapModelTraining,true);
2829

2930
var wordJsonPath = Global.Absolute(@"Assets\Model\Yap\index_2_word.json");
3031
if (!File.Exists(wordJsonPath)) throw new FileNotFoundException("Yap字典文件不存在", wordJsonPath);

BetterGenshinImpact/GameTask/AutoDomain/AutoDomainTask.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
using System.Collections.ObjectModel;
4141
using BetterGenshinImpact.Core.Script.Dependence;
4242
using Compunet.YoloSharp;
43+
using Microsoft.Extensions.DependencyInjection;
4344

4445
namespace BetterGenshinImpact.GameTask.AutoDomain;
4546

@@ -72,7 +73,7 @@ public AutoDomainTask(AutoDomainParam taskParam)
7273
{
7374
AutoFightAssets.DestroyInstance();
7475
_taskParam = taskParam;
75-
_predictor = BgiOnnxFactory.Instance.CreateYoloPredictor(BgiOnnxModel.BgiTree);
76+
_predictor = App.ServiceProvider.GetRequiredService<BgiOnnxFactory>().CreateYoloPredictor(BgiOnnxModel.BgiTree);
7677

7778
_config = TaskContext.Instance().Config.AutoDomainConfig;
7879

BetterGenshinImpact/GameTask/AutoFight/AutoFightTask.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using BetterGenshinImpact.Helpers;
2020
using Vanara;
2121
using Vanara.PInvoke;
22+
using Microsoft.Extensions.DependencyInjection;
2223

2324

2425
namespace BetterGenshinImpact.GameTask.AutoFight;
@@ -185,7 +186,7 @@ public AutoFightTask(AutoFightParam taskParam)
185186

186187
if (_taskParam.FightFinishDetectEnabled)
187188
{
188-
_predictor = BgiOnnxFactory.Instance.CreateYoloPredictor(BgiOnnxModel.BgiWorld);
189+
_predictor = App.ServiceProvider.GetRequiredService<BgiOnnxFactory>().CreateYoloPredictor(BgiOnnxModel.BgiWorld);
189190
}
190191

191192
_finishDetectConfig = new TaskFightFinishDetectConfig(_taskParam.FinishDetectConfig);

0 commit comments

Comments
 (0)