11#!/usr/bin/env python
22# 通用向量检索工具:封装 sentence-transformers 的向量计算与粗筛逻辑
33
4+ from __future__ import annotations
5+
46import os
57import numpy as np
6- from typing import Any , Dict , List
8+ from typing import Any , Dict , List , TYPE_CHECKING
79import time
810from datetime import datetime , timezone
911
1012os .environ .setdefault ("HF_HUB_DISABLE_SYMLINKS" , "1" )
1113
12- import torch
13- from sentence_transformers import SentenceTransformer
14+ from model_loader import is_remote_embedding_enabled , load_sentence_transformer
1415
15- from model_loader import load_sentence_transformer
16+ if TYPE_CHECKING :
17+ from sentence_transformers import SentenceTransformer
1618
1719# E5 系列推荐使用 query/passage 前缀来区分检索侧与文档侧
1820E5_QUERY_PREFIX = "query: "
@@ -30,6 +32,8 @@ def debug_hf_runtime(prefix: str) -> None:
3032 enable = (os .getenv ("DPR_DEBUG_HF" ) == "1" ) or (os .getenv ("GITHUB_ACTIONS" ) == "true" )
3133 if not enable :
3234 return
35+ if is_remote_embedding_enabled ():
36+ return
3337
3438 log (f"[DEBUG][HF] { prefix } " )
3539 keys = [
@@ -73,7 +77,7 @@ def ls_dir(path: str) -> None:
7377 ls_dir (hf_home )
7478
7579
76- def _set_max_seq_length (model : SentenceTransformer , max_length : int | None ) -> None :
80+ def _set_max_seq_length (model : Any , max_length : int | None ) -> None :
7781 """尽量通过 SentenceTransformer 的 max_seq_length 控制截断长度。"""
7882 if max_length is None or max_length <= 0 :
7983 return
@@ -93,7 +97,7 @@ def _set_max_seq_length(model: SentenceTransformer, max_length: int | None) -> N
9397
9498
9599def encode_queries (
96- model : SentenceTransformer ,
100+ model : Any ,
97101 texts : List [str ],
98102 batch_size : int = 8 ,
99103 max_length : int | None = None ,
@@ -128,7 +132,7 @@ def encode_queries(
128132
129133
130134def compute_embeddings (
131- model : SentenceTransformer ,
135+ model : Any ,
132136 items : List [Any ],
133137 batch_size : int = 8 ,
134138 max_length : int | None = None ,
@@ -206,15 +210,27 @@ def __init__(
206210 self .batch_size = batch_size
207211 self .max_length = max_length
208212
213+ remote_mode = is_remote_embedding_enabled ()
209214 if device is None :
210- self .device = "cuda" if torch .cuda .is_available () else "cpu"
215+ if remote_mode :
216+ self .device = "remote"
217+ else :
218+ try :
219+ import torch
220+ self .device = "cuda" if torch .cuda .is_available () else "cpu"
221+ except Exception :
222+ self .device = "cpu"
211223 else :
212- self .device = device
224+ self .device = device if not remote_mode else "remote"
213225
214- print (f"[INFO] 正在加载向量模型:{ self .model_name } ,device={ self .device } " )
215- debug_hf_runtime ("before SentenceTransformer()" )
226+ if remote_mode :
227+ print (f"[INFO] 正在初始化远程向量服务:{ self .model_name } ,device={ self .device } " )
228+ else :
229+ print (f"[INFO] 正在加载本地向量模型:{ self .model_name } ,device={ self .device } " )
230+ debug_hf_runtime ("before SentenceTransformer()" )
216231 self .model = load_sentence_transformer (self .model_name , device = self .device )
217- debug_hf_runtime ("after SentenceTransformer()" )
232+ if not remote_mode :
233+ debug_hf_runtime ("after SentenceTransformer()" )
218234 _set_max_seq_length (self .model , self .max_length )
219235
220236 def filter (self , items : List [Any ], queries : List [Dict [str , Any ]]) -> Dict [str , Any ]:
0 commit comments