-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathonnx.go
More file actions
96 lines (82 loc) · 2.32 KB
/
onnx.go
File metadata and controls
96 lines (82 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package vision
import (
"fmt"
ort "github.com/getcharzp/onnxruntime_purego"
"runtime"
"sync"
)
type OnnxConfig struct {
SessionOptions *ort.SessionOptions
OnnxEngine *ort.Engine
// 必填参数
OnnxRuntimeLibPath string // onnxruntime.dll (或 .so, .dylib) 的路径
// 可选参数
UseCuda bool // (可选) 是否启用 CUDA
NumThreads int // (可选) ONNX 线程数, 默认由CPU核心数决定
// EnableCpuMemArena 控制 ONNX 的内存池策略
// false (默认): 禁用内存池,推理速度稍慢,但 Destroy 后立即归还内存给 OS ,解决内存滞留问题
// true: 启用内存池,推理速度最快,但 Destroy 后内存会被缓存以供复用
EnableCpuMemArena bool
}
var (
initErr error
once sync.Once
onnxEngine *ort.Engine
)
// New 初始化 ONNX 环境
func (cfg *OnnxConfig) New() error {
// 初始化 ONNX Runtime
if cfg.OnnxRuntimeLibPath == "" {
return fmt.Errorf("OnnxRuntimeLibPath 不能为空")
}
once.Do(func() {
onnxEngine, initErr = ort.NewEngine(cfg.OnnxRuntimeLibPath)
})
if initErr != nil {
return fmt.Errorf("初始化 ONNX Engine 失败: %w", initErr)
}
// 创建会话选项 (设置线程)
options, err := onnxEngine.NewSessionOptions()
if err != nil {
return err
}
if cfg.NumThreads > 0 {
if err := options.SetIntraOpNumThreads(int32(cfg.NumThreads)); err != nil {
return err
}
}
// 设置内存策略
if err := options.SetCpuMemArena(cfg.EnableCpuMemArena); err != nil {
return fmt.Errorf("设置 CPU 内存池失败: %w", err)
}
// 启用CUDA
if cfg.UseCuda {
if err := options.EnableCUDA(); err != nil {
return fmt.Errorf("启用 CUDA 失败: %w", err)
}
}
cfg.SessionOptions = options
cfg.OnnxEngine = onnxEngine
return nil
}
// DefaultLibraryPath 根据运行时环境判断加载哪个库文件
func DefaultLibraryPath() string {
baseDir := "./lib/"
libName := "onnxruntime"
// windows onnxruntime.dll
if runtime.GOOS == "windows" {
return baseDir + libName + ".dll"
}
// linux darwin ext
var ext string
switch runtime.GOOS {
case "darwin":
ext = "dylib"
case "linux":
ext = "so"
default:
return baseDir + libName + "_amd64.so" // 默认返回 linux amd64
}
// 拼接完整路径: ./lib/onnxruntime + _ + amd64/arm64 + . + so/dylib
return fmt.Sprintf("%s%s_%s.%s", baseDir, libName, runtime.GOARCH, ext)
}