55from contextlib import contextmanager
66import os
77import time
8- from typing import Callable , Optional , TYPE_CHECKING
8+ from typing import Any , Callable , Optional , TYPE_CHECKING
99
1010import numpy as np
1111import requests
@@ -44,6 +44,12 @@ def __init__(
4444 api_key : str = "" ,
4545 timeout : int = _DEFAULT_REMOTE_TIMEOUT_SECONDS ,
4646 default_batch_size : int = 8 ,
47+ local_device : str = "cpu" ,
48+ local_retries : int | None = None ,
49+ local_providers : tuple [tuple [str , str ], ...] = (
50+ ("huggingface" , HUGGINGFACE_ENDPOINT ),
51+ ("modelscope" , MODELSCOPE_ENDPOINT ),
52+ ),
4753 log : Callable [[str ], None ] = _log_default ,
4854 ):
4955 self .model_name = model_name
@@ -52,6 +58,10 @@ def __init__(
5258 self .timeout = max (int (timeout or _DEFAULT_REMOTE_TIMEOUT_SECONDS ), 1 )
5359 self .default_batch_size = max (int (default_batch_size or 1 ), 1 )
5460 self .max_seq_length = None
61+ self .local_device = str (local_device or "cpu" )
62+ self .local_retries = local_retries
63+ self .local_providers = local_providers
64+ self ._local_model = None
5565 self ._log = log
5666
5767 @staticmethod
@@ -71,6 +81,26 @@ def _headers(self) -> dict[str, str]:
7181 headers ["Authorization" ] = f"Bearer { self .api_key } "
7282 return headers
7383
def _get_local_model(self):
    """Return the local fallback embedding model, loading it on first use.

    On the first call a warning is emitted through ``self._log`` and the
    model is created by ``_load_local_sentence_transformer`` using this
    instance's ``model_name``, ``local_device``, ``local_retries``,
    ``local_providers`` and logger. The result is cached on
    ``self._local_model`` and reused by later calls.

    When ``self.max_seq_length`` is not ``None`` and the model exposes a
    ``max_seq_length`` attribute, the value is copied onto the model.
    Applying it is best-effort: a failure is reported through ``self._log``
    and never raised, so a model that rejects the setting is still returned.

    Returns:
        The cached local model object.
    """
    if self._local_model is None:
        self._log(
            f"[WARN] 远程 embedding 不可用,回退本地模型:{self.model_name} "
            f"(device={self.local_device})"
        )
        self._local_model = _load_local_sentence_transformer(
            self.model_name,
            device=self.local_device,
            retries=self.local_retries,
            log=self._log,
            providers=self.local_providers,
        )

    model = self._local_model
    limit = self.max_seq_length
    # NOTE(review): the sync runs on every call (not only right after
    # loading) so a limit configured after the first fallback still reaches
    # the cached model; it is a no-op once both values agree.
    if limit is not None and hasattr(model, "max_seq_length"):
        try:
            if model.max_seq_length != limit:
                model.max_seq_length = limit
        except Exception as exc:
            # Best-effort only, but leave a trace: a silently ignored
            # failure means the fallback truncates inputs differently
            # from what the caller configured.
            self._log(
                f"[WARN] 本地 embedding 模型设置 max_seq_length={limit} 失败,已忽略:{exc}"
            )
    return model
103+
74104 def encode (
75105 self ,
76106 texts ,
@@ -80,7 +110,6 @@ def encode(
80110 show_progress_bar : bool = False ,
81111 ** kwargs ,
82112 ):
83- del show_progress_bar , kwargs
84113 if isinstance (texts , str ):
85114 texts = [texts ]
86115 if not isinstance (texts , list ):
@@ -90,60 +119,78 @@ def encode(
90119 return empty if convert_to_numpy else empty .tolist ()
91120
92121 safe_batch_size = max (int (batch_size or self .default_batch_size ), 1 )
93- chunks = [texts [i : i + safe_batch_size ] for i in range (0 , len (texts ), safe_batch_size )]
94- outputs : list [np .ndarray ] = []
95-
96- self ._log (
97- f"[INFO] 远程 embedding:model={ self .model_name } "
98- f"endpoint={ self .endpoint } total={ len (texts )} batch={ safe_batch_size } "
99- )
122+ try :
123+ chunks = [texts [i : i + safe_batch_size ] for i in range (0 , len (texts ), safe_batch_size )]
124+ outputs : list [np .ndarray ] = []
100125
101- for chunk_index , chunk in enumerate (chunks , start = 1 ):
102- headers = self ._headers ()
103- response = requests .post (
104- self .endpoint ,
105- headers = headers ,
106- json = {"texts" : chunk },
107- timeout = self .timeout ,
126+ self ._log (
127+ f"[INFO] 远程 embedding:model={ self .model_name } "
128+ f"endpoint={ self .endpoint } total={ len (texts )} batch={ safe_batch_size } "
108129 )
109- if response .status_code == 401 and headers .get ("Authorization" ):
110- self ._log ("[WARN] 远程 embedding 鉴权失败,自动回退为无鉴权请求重试一次。" )
111- headers = {
112- "Content-Type" : "application/json" ,
113- }
130+
131+ for chunk_index , chunk in enumerate (chunks , start = 1 ):
132+ headers = self ._headers ()
114133 response = requests .post (
115134 self .endpoint ,
116135 headers = headers ,
117136 json = {"texts" : chunk },
118137 timeout = self .timeout ,
119138 )
120- response .raise_for_status ()
121- data = response .json ()
122- embeddings = data .get ("embeddings" )
123- if not isinstance (embeddings , list ):
124- raise RuntimeError ("远程 embedding 服务返回缺少 embeddings 字段" )
125- try :
126- arr = np .asarray (embeddings , dtype = np .float32 )
127- except Exception as exc :
128- raise RuntimeError (f"远程 embedding 返回无法转换为 float32:{ exc } " ) from exc
129-
130- if arr .ndim != 2 :
131- raise RuntimeError (f"远程 embedding 返回维度异常:shape={ getattr (arr , 'shape' , None )} " )
132- if arr .shape [0 ] != len (chunk ):
133- raise RuntimeError (
134- f"远程 embedding 返回条数异常:expected={ len (chunk )} actual={ arr .shape [0 ]} "
139+ if response .status_code == 401 and headers .get ("Authorization" ):
140+ self ._log ("[WARN] 远程 embedding 鉴权失败,自动回退为无鉴权请求重试一次。" )
141+ headers = {
142+ "Content-Type" : "application/json" ,
143+ }
144+ response = requests .post (
145+ self .endpoint ,
146+ headers = headers ,
147+ json = {"texts" : chunk },
148+ timeout = self .timeout ,
149+ )
150+ response .raise_for_status ()
151+ data = response .json ()
152+ embeddings = data .get ("embeddings" )
153+ if not isinstance (embeddings , list ):
154+ raise RuntimeError ("远程 embedding 服务返回缺少 embeddings 字段" )
155+ try :
156+ arr = np .asarray (embeddings , dtype = np .float32 )
157+ except Exception as exc :
158+ raise RuntimeError (f"远程 embedding 返回无法转换为 float32:{ exc } " ) from exc
159+
160+ if arr .ndim != 2 :
161+ raise RuntimeError (f"远程 embedding 返回维度异常:shape={ getattr (arr , 'shape' , None )} " )
162+ if arr .shape [0 ] != len (chunk ):
163+ raise RuntimeError (
164+ f"远程 embedding 返回条数异常:expected={ len (chunk )} actual={ arr .shape [0 ]} "
165+ )
166+ if normalize_embeddings :
167+ norms = np .linalg .norm (arr , axis = 1 , keepdims = True )
168+ arr = arr / np .clip (norms , 1e-12 , None )
169+ outputs .append (arr )
170+ self ._log (
171+ f"[INFO] 远程 embedding 批次完成:{ chunk_index } /{ len (chunks )} "
172+ f"count={ len (chunk )} dim={ arr .shape [1 ]} "
135173 )
136- if normalize_embeddings :
137- norms = np .linalg .norm (arr , axis = 1 , keepdims = True )
138- arr = arr / np .clip (norms , 1e-12 , None )
139- outputs .append (arr )
140- self ._log (
141- f"[INFO] 远程 embedding 批次完成:{ chunk_index } /{ len (chunks )} "
142- f"count={ len (chunk )} dim={ arr .shape [1 ]} "
143- )
144174
145- merged = np .vstack (outputs ) if outputs else np .zeros ((0 , 0 ), dtype = np .float32 )
146- return merged if convert_to_numpy else merged .tolist ()
175+ merged = np .vstack (outputs ) if outputs else np .zeros ((0 , 0 ), dtype = np .float32 )
176+ return merged if convert_to_numpy else merged .tolist ()
177+ except Exception as exc :
178+ self ._log (f"[WARN] 远程 embedding 请求失败,将自动回退本地模型:{ exc } " )
179+ local_model = self ._get_local_model ()
180+ result = local_model .encode (
181+ texts ,
182+ convert_to_numpy = convert_to_numpy ,
183+ normalize_embeddings = normalize_embeddings ,
184+ batch_size = safe_batch_size ,
185+ show_progress_bar = show_progress_bar ,
186+ ** kwargs ,
187+ )
188+ if convert_to_numpy and not isinstance (result , np .ndarray ):
189+ try :
190+ result = np .asarray (result , dtype = np .float32 )
191+ except Exception :
192+ pass
193+ return result
147194
148195 def start_multi_process_pool (self , target_devices = None ):
149196 del target_devices
@@ -266,9 +313,32 @@ def load_sentence_transformer(
266313 endpoint = str (remote_endpoint ).strip (),
267314 api_key = remote_api_key ,
268315 timeout = remote_timeout ,
316+ local_device = device ,
317+ local_retries = retries ,
318+ local_providers = providers ,
269319 log = log ,
270320 )
271321
322+ return _load_local_sentence_transformer (
323+ model_name ,
324+ device = device ,
325+ retries = retries ,
326+ log = log ,
327+ providers = providers ,
328+ )
329+
330+
331+ def _load_local_sentence_transformer (
332+ model_name : str ,
333+ * ,
334+ device : str ,
335+ retries : int | None = None ,
336+ log : Callable [[str ], None ] = _log_default ,
337+ providers : tuple [tuple [str , str ], ...] = (
338+ ("huggingface" , HUGGINGFACE_ENDPOINT ),
339+ ("modelscope" , MODELSCOPE_ENDPOINT ),
340+ ),
341+ ):
272342 if retries is None :
273343 env_retries = os .getenv ("LLM_EMBED_MODEL_RETRIES" )
274344 if env_retries is None :
0 commit comments