1+ package com .phantoms .phantomsbackend .common .utils ;
2+
3+ import ai .onnxruntime .*;
4+ import java .util .Map ;
5+
6+ public class LiteEmbeddingSimilarity {
7+
8+ // 全局唯一模型(只加载1次,内存固定)
9+ private static OrtEnvironment env ;
10+ private static OrtSession session ;
11+
12+ static {
13+ try {
14+ // 初始化 ONNX 环境(极轻量)
15+ env = OrtEnvironment .getEnvironment ();
16+ // 模型路径(放入 resources 下)
17+ String modelPath = LiteEmbeddingSimilarity .class .getResource ("/model/onnx/model.onnx" ).getPath ();
18+
19+ // 会话配置(最小内存)
20+ OrtSession .SessionOptions opts = new OrtSession .SessionOptions ();
21+ opts .setIntraOpNumThreads (1 ); // 单线程,省内存
22+ opts .setOptimizationLevel (OrtSession .SessionOptions .OptLevel .BASIC_OPT );
23+
24+ // 加载模型(22MB)
25+ session = env .createSession (modelPath , opts );
26+ } catch (Exception e ) {
27+ e .printStackTrace ();
28+ }
29+ }
30+
31+ // ===================== 核心:生成向量 =====================
32+ public static float [] getEmbedding (String text ) throws Exception {
33+ try (OrtSession .Result result = session .run (Map .of (
34+ "input_ids" , OnnxTensor .createTensor (env , new long [][]{{101 , 102 }}), // 简化示例,真实需替换为 token 结果
35+ "attention_mask" , OnnxTensor .createTensor (env , new long [][]{{1 , 1 }})
36+ ))) {
37+ return (float []) result .get (0 ).getValue ();
38+ }
39+ }
40+
41+ // ===================== 余弦相似度 =====================
42+ public static double cosineSim (float [] a , float [] b ) {
43+ double dot = 0 , normA = 0 , normB = 0 ;
44+ for (int i = 0 ; i < a .length ; i ++) {
45+ dot += a [i ] * b [i ];
46+ normA += a [i ] * a [i ];
47+ normB += b [i ] * b [i ];
48+ }
49+ return dot / (Math .sqrt (normA ) * Math .sqrt (normB ));
50+ }
51+
52+ // ===================== 对外接口 =====================
53+ public static double similarity (String name1 , String name2 ) throws Exception {
54+ float [] v1 = getEmbedding (name1 );
55+ float [] v2 = getEmbedding (name2 );
56+ return cosineSim (v1 , v2 );
57+ }
58+ }
0 commit comments