1+ import torch
2+ import uvicorn
3+ from fastapi import FastAPI
4+ from pydantic import BaseModel , Field
5+ from transformers import AutoModelForSequenceClassification , AutoTokenizer
6+ from typing import List
7+
# --- 1. Define the API request and response data structures ---
9+
# Request body for the rerank endpoint; its structure is unchanged.
class RerankRequest(BaseModel):
    # Query text that every document is scored against.
    query: str
    # Candidate documents to rerank.
    documents: List[str]
14+
# --- Change start: define the test response structure, whose score field is named "score" ---
16+
# Nested document object of a ranked result; its structure is unchanged.
class DocumentInfo(BaseModel):
    # Text of the document.
    text: str
20+
# Test variant of the former GoRankResult.
# Core change: the "relevance_score" field is renamed to "score".
class TestRankResult(BaseModel):
    # Index associated with the document.
    index: int
    # Nested object carrying the document text.
    document: DocumentInfo
    # Key change: this field was renamed from relevance_score to score.
    score: float
27+
# Final response body; its "results" list holds TestRankResult items.
class TestFinalResponse(BaseModel):
    results: List[TestRankResult]
31+
# --- Change end ---
33+
34+
# --- 2. Load the model (runs once, when the service starts) ---
print("正在加载模型,请稍候...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用的设备: {device} ")

# Make sure this path is correct for the host the service runs on.
model_path = '/data1/home/lwx/work/Download/rerank_model_weight'

try:
    # Only the calls that can actually fail live inside the try body.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()  # switch to evaluation mode; the service only runs inference
except Exception as e:
    print(f"模型加载失败: {e} ")
    # Without a model the service cannot answer requests, so stop here.
    # Exit with a non-zero status so that shells and process supervisors see
    # the failure; the bare exit() helper would report success (status 0)
    # and is only provided by the `site` module for interactive use.
    raise SystemExit(1)
else:
    print("模型加载成功!")
51+
# --- 3. Create the FastAPI application ---
# The title, description and version below are the application metadata
# handed to FastAPI.
app = FastAPI(
    title="Reranker API (Test Version)",
    description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
    version="1.0.1",
)
58+
# --- 4. Define the API endpoints ---
# Key change: response_model points at the test response structure
# (TestFinalResponse), whose result items expose "score".
@app.post("/rerank", response_model=TestFinalResponse)
def rerank_endpoint(request: RerankRequest):
    """Score each document against the query and return them best-first.

    Every (query, document) pair is tokenized (padded, truncated with
    max_length=1024) and passed through the sequence-classification model;
    the flattened logits are used as the scores.

    Args:
        request: The query and the list of documents to rerank.

    Returns:
        A dict ``{"results": [...]}`` whose items are ``TestRankResult``
        objects sorted by ``score`` in descending order. ``index`` is the
        document's position in ``request.documents``. FastAPI validates and
        serializes the dict according to ``TestFinalResponse``, producing
        ``{"results": [{"index": ..., "document": ..., "score": ...}]}``.
    """
    documents = request.documents
    if not documents:
        # Nothing to rank: answer with an empty result list instead of
        # pushing an empty batch through the tokenizer and the model.
        return {"results": []}

    pairs = [[request.query, doc] for doc in documents]

    with torch.no_grad():
        inputs = tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=1024,
        ).to(device)
        logits = model(**inputs, return_dict=True).logits
        # Convert the whole batch to Python floats in one call rather than
        # calling .item() once per document.
        scores = logits.view(-1).float().tolist()

    # Build the results in the test structure: index, document, score.
    results = [
        TestRankResult(
            index=position,
            document=DocumentInfo(text=text),
            score=score_val,  # key change: stored under "score"
        )
        for position, (text, score_val) in enumerate(zip(documents, scores))
    ]

    # Highest score first (the sort key is the renamed "score" field).
    results.sort(key=lambda item: item.score, reverse=True)

    return {"results": results}
94+
@app.get("/")
def read_root():
    # Root endpoint: returns a fixed status message saying the test
    # version of the API is running.
    status_payload = {"status": "Reranker API (Test Version) is running"}
    return status_payload
98+
# --- 5. Start the service ---
# Only when executed as a script: run the app with uvicorn on 0.0.0.0:8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )
102+
0 commit comments