Skip to content

Commit 19c5e24

Browse files
Liwx1014 authored and lyingbug committed
兼容解析重排序服务返回score字段
1 parent fb1cd98 commit 19c5e24

File tree

2 files changed

+131
-1
lines changed

2 files changed

+131
-1
lines changed

internal/models/rerank/reranker.go

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package rerank
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"strings"
7-
8+
89
"github.com/Tencent/WeKnora/internal/types"
910
)
1011

@@ -25,6 +26,33 @@ type RankResult struct {
2526
Document DocumentInfo `json:"document"`
2627
RelevanceScore float64 `json:"relevance_score"`
2728
}
29+
//Handles the RelevanceScore field by checking if RelevanceScore exists first, otherwise falls back to Score field
30+
func (r *RankResult) UnmarshalJSON(data []byte) error {
31+
32+
var temp struct {
33+
Index int `json:"index"`
34+
Document DocumentInfo `json:"document"`
35+
RelevanceScore *float64 `json:"relevance_score"`
36+
Score *float64 `json:"score"`
37+
}
38+
39+
40+
if err := json.Unmarshal(data, &temp); err != nil {
41+
return fmt.Errorf("failed to unmarshal rank result: %w", err)
42+
}
43+
44+
r.Index = temp.Index
45+
r.Document = temp.Document
46+
47+
if temp.RelevanceScore != nil {
48+
r.RelevanceScore = *temp.RelevanceScore
49+
} else if temp.Score != nil {
50+
r.RelevanceScore = *temp.Score
51+
}
52+
53+
54+
return nil
55+
}
2856

2957
type DocumentInfo struct {
3058
Text string `json:"text"`

rerank_server_demo.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import os
from typing import List

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
7+
8+
# --- 1. 定义API的请求和响应数据结构 ---
9+
10+
# 请求体结构保持不变
11+
class RerankRequest(BaseModel):
    # Request body accepted by the POST /rerank endpoint.
    # (Plain comments are used instead of a docstring so the generated
    # OpenAPI schema stays exactly as it was.)
    query: str  # text that every document is paired with for scoring
    documents: List[str]  # candidate texts; their list position becomes "index"
14+
15+
# --- 修改开始:定义测试用的响应结构,字段名为 "score" ---
16+
17+
# DocumentInfo 结构保持不变
18+
class DocumentInfo(BaseModel):
    # Nested "document" object of a rank result; carries only the raw text.
    text: str
20+
21+
# 将原来的 GoRankResult 修改为 TestRankResult
22+
# 核心改动:将 "relevance_score" 字段重命名为 "score"
23+
class TestRankResult(BaseModel):
    # One scored document in the test response.
    index: int  # position of the document in the request's "documents" list
    document: DocumentInfo
    # Key point of this test server: the field is named "score" rather than
    # "relevance_score".
    score: float
27+
28+
# 最终响应体结构,其 "results" 列表包含的是 TestRankResult
29+
class TestFinalResponse(BaseModel):
    # Top-level response body: {"results": [<TestRankResult>, ...]}.
    results: List[TestRankResult]
31+
32+
# --- 修改结束 ---
33+
34+
35+
# --- 2. 加载模型 (在服务启动时执行一次) ---
36+
# Model loading runs once, when this module is imported / the server starts.
print("正在加载模型，请稍候...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的设备: {device}")

# The checkpoint directory can be overridden with the RERANK_MODEL_PATH
# environment variable; the default is the path that used to be hard-coded.
model_path = os.environ.get(
    "RERANK_MODEL_PATH",
    "/data1/home/lwx/work/Download/rerank_model_weight",
)

try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()  # inference mode
    print("模型加载成功！")
except Exception as e:
    print(f"模型加载失败: {e}")
    # Without a model the service is useless, so stop here. SystemExit(1) is
    # raised explicitly: the bare exit() helper comes from the `site` module
    # (intended for interactive use) and would report exit status 0, i.e.
    # success, to whatever launched the process.
    raise SystemExit(1) from e
51+
52+
# --- 3. 创建FastAPI应用 ---
53+
# FastAPI application object; the keyword arguments below are the metadata
# passed to FastAPI for this test server.
app = FastAPI(
    version="1.0.1",
    title="Reranker API (Test Version)",
    description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
)
58+
59+
# --- 4. 定义API端点 ---
60+
# --- 修改开始:将 response_model 指向新的测试用响应结构 ---
61+
@app.post("/rerank", response_model=TestFinalResponse)  # response uses the "score" field name
def rerank_endpoint(request: RerankRequest):
    """Score each document against the query and return them best-first.

    The response has the shape
    ``{"results": [{"index": ..., "document": {"text": ...}, "score": ...}]}``,
    where ``index`` is the document's position in the request and the list is
    ordered by descending ``score``. An empty ``documents`` list yields an
    empty ``results`` list.
    """
    # Nothing to score: return early instead of handing the tokenizer and
    # model an empty batch, which they may not accept.
    if not request.documents:
        return {"results": []}

    pairs = [[request.query, doc] for doc in request.documents]

    with torch.no_grad():
        inputs = tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=1024,
        ).to(device)
        # Flatten the logits and convert to Python floats in a single call.
        # NOTE(review): this assumes the model emits exactly one logit per
        # (query, document) pair; with more labels the flattened scores
        # would no longer line up with the documents - confirm for the
        # checkpoint in use.
        scores = model(**inputs, return_dict=True).logits.view(-1).float().tolist()

    # Field names are index / document / score (not relevance_score).
    results = [
        TestRankResult(
            index=i,
            document=DocumentInfo(text=text),
            score=score_val,
        )
        for i, (text, score_val) in enumerate(zip(request.documents, scores))
    ]

    # Highest score first; list.sort is stable, so ties keep request order.
    results.sort(key=lambda item: item.score, reverse=True)

    # FastAPI validates and serializes this dict against TestFinalResponse.
    return {"results": results}
94+
95+
@app.get("/")
def read_root():
    # Simple status endpoint: always answers with the same fixed payload.
    # (Comment instead of docstring so the OpenAPI description is unchanged.)
    status_payload = {"status": "Reranker API (Test Version) is running"}
    return status_payload
98+
99+
# --- 5. 启动服务 ---
100+
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all
    # interfaces (0.0.0.0), port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
102+

0 commit comments

Comments
 (0)