Skip to content

Commit a0d1dc6

Browse files
committed
add Embedding Similarity function
1 parent 9470e5e commit a0d1dc6

8 files changed

Lines changed: 61319 additions & 0 deletions

File tree

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,12 @@
358358
<!-- <artifactId>dotenv-java</artifactId>-->
359359
<!-- <version>2.3.2</version>-->
360360
<!-- </dependency>-->
361+
<!-- ONNX Runtime -->
362+
<dependency>
363+
<groupId>com.microsoft.onnxruntime</groupId>
364+
<artifactId>onnxruntime</artifactId>
365+
<version>1.17.1</version>
366+
</dependency>
361367
</dependencies>
362368

363369
<build>
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package com.phantoms.phantomsbackend.common.utils;
2+
3+
import ai.onnxruntime.*;
4+
import java.util.Map;
5+
6+
public class LiteEmbeddingSimilarity {
7+
8+
// 全局唯一模型(只加载1次,内存固定)
9+
private static OrtEnvironment env;
10+
private static OrtSession session;
11+
12+
static {
13+
try {
14+
// 初始化 ONNX 环境(极轻量)
15+
env = OrtEnvironment.getEnvironment();
16+
// 模型路径(放入 resources 下)
17+
String modelPath = LiteEmbeddingSimilarity.class.getResource("/model/onnx/model.onnx").getPath();
18+
19+
// 会话配置(最小内存)
20+
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
21+
opts.setIntraOpNumThreads(1); // 单线程,省内存
22+
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.BASIC_OPT);
23+
24+
// 加载模型(22MB)
25+
session = env.createSession(modelPath, opts);
26+
} catch (Exception e) {
27+
e.printStackTrace();
28+
}
29+
}
30+
31+
// ===================== 核心:生成向量 =====================
32+
public static float[] getEmbedding(String text) throws Exception {
33+
try (OrtSession.Result result = session.run(Map.of(
34+
"input_ids", OnnxTensor.createTensor(env, new long[][]{{101, 102}}), // 简化示例,真实需替换为 token 结果
35+
"attention_mask", OnnxTensor.createTensor(env, new long[][]{{1, 1}})
36+
))) {
37+
return (float[]) result.get(0).getValue();
38+
}
39+
}
40+
41+
// ===================== 余弦相似度 =====================
42+
public static double cosineSim(float[] a, float[] b) {
43+
double dot = 0, normA = 0, normB = 0;
44+
for (int i = 0; i < a.length; i++) {
45+
dot += a[i] * b[i];
46+
normA += a[i] * a[i];
47+
normB += b[i] * b[i];
48+
}
49+
return dot / (Math.sqrt(normA) * Math.sqrt(normB));
50+
}
51+
52+
// ===================== 对外接口 =====================
53+
public static double similarity(String name1, String name2) throws Exception {
54+
float[] v1 = getEmbedding(name1);
55+
float[] v2 = getEmbedding(name2);
56+
return cosineSim(v1, v2);
57+
}
58+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
3+
"architectures": [
4+
"BertModel"
5+
],
6+
"attention_probs_dropout_prob": 0.1,
7+
"classifier_dropout": null,
8+
"gradient_checkpointing": false,
9+
"hidden_act": "gelu",
10+
"hidden_dropout_prob": 0.1,
11+
"hidden_size": 384,
12+
"initializer_range": 0.02,
13+
"intermediate_size": 1536,
14+
"layer_norm_eps": 1e-12,
15+
"max_position_embeddings": 512,
16+
"model_type": "bert",
17+
"num_attention_heads": 12,
18+
"num_hidden_layers": 6,
19+
"pad_token_id": 0,
20+
"position_embedding_type": "absolute",
21+
"transformers_version": "4.27.4",
22+
"type_vocab_size": 2,
23+
"use_cache": true,
24+
"vocab_size": 30522
25+
}
86.2 MB
Binary file not shown.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"cls_token": "[CLS]",
3+
"mask_token": "[MASK]",
4+
"pad_token": "[PAD]",
5+
"sep_token": "[SEP]",
6+
"unk_token": "[UNK]"
7+
}

0 commit comments

Comments
 (0)