Skip to content

Commit 202829d

Browse files
committed
[fix] 恢复远程embedding失败时的本地兜底
1 parent 768c6f4 commit 202829d

3 files changed

Lines changed: 151 additions & 55 deletions

File tree

.github/workflows/daily-paper-reader.yml

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -55,23 +55,21 @@ jobs:
5555
with:
5656
python-version: "3.11"
5757

58-
- name: Cache pip
58+
- name: Cache Python deps
5959
uses: actions/cache@v4
6060
with:
61-
path: ~/.cache/pip
62-
key: ${{ runner.os }}-dpr-remote-pip-v1-${{ hashFiles('requirements.txt') }}
61+
path: |
62+
~/.cache/pip
63+
~/.cache/uv
64+
~/.cache/torch
65+
key: ${{ runner.os }}-dpr-embed-deps-v1-${{ hashFiles('requirements.txt') }}
6366

6467
- name: Install deps (skip sqlite3)
6568
run: |
6669
python - <<'PY'
6770
import re
6871
lines = open("requirements.txt", "r", encoding="utf-8").read().splitlines()
69-
lines = [
70-
l for l in lines
71-
if l.strip()
72-
and not re.match(r"^sqlite3\\b", l)
73-
and not re.match(r"^sentence-transformers\\b", l)
74-
]
72+
lines = [l for l in lines if l.strip() and not re.match(r"^sqlite3\\b", l)]
7573
open("/tmp/req.txt", "w", encoding="utf-8").write("\n".join(lines))
7674
PY
7775
python -m pip install --upgrade pip

src/model_loader.py

Lines changed: 116 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from contextlib import contextmanager
66
import os
77
import time
8-
from typing import Callable, Optional, TYPE_CHECKING
8+
from typing import Any, Callable, Optional, TYPE_CHECKING
99

1010
import numpy as np
1111
import requests
@@ -44,6 +44,12 @@ def __init__(
4444
api_key: str = "",
4545
timeout: int = _DEFAULT_REMOTE_TIMEOUT_SECONDS,
4646
default_batch_size: int = 8,
47+
local_device: str = "cpu",
48+
local_retries: int | None = None,
49+
local_providers: tuple[tuple[str, str], ...] = (
50+
("huggingface", HUGGINGFACE_ENDPOINT),
51+
("modelscope", MODELSCOPE_ENDPOINT),
52+
),
4753
log: Callable[[str], None] = _log_default,
4854
):
4955
self.model_name = model_name
@@ -52,6 +58,10 @@ def __init__(
5258
self.timeout = max(int(timeout or _DEFAULT_REMOTE_TIMEOUT_SECONDS), 1)
5359
self.default_batch_size = max(int(default_batch_size or 1), 1)
5460
self.max_seq_length = None
61+
self.local_device = str(local_device or "cpu")
62+
self.local_retries = local_retries
63+
self.local_providers = local_providers
64+
self._local_model = None
5565
self._log = log
5666

5767
@staticmethod
@@ -71,6 +81,26 @@ def _headers(self) -> dict[str, str]:
7181
headers["Authorization"] = f"Bearer {self.api_key}"
7282
return headers
7383

84+
def _get_local_model(self):
85+
if self._local_model is None:
86+
self._log(
87+
f"[WARN] 远程 embedding 不可用,回退本地模型:{self.model_name} "
88+
f"(device={self.local_device})"
89+
)
90+
self._local_model = _load_local_sentence_transformer(
91+
self.model_name,
92+
device=self.local_device,
93+
retries=self.local_retries,
94+
log=self._log,
95+
providers=self.local_providers,
96+
)
97+
if self.max_seq_length is not None and hasattr(self._local_model, "max_seq_length"):
98+
try:
99+
self._local_model.max_seq_length = self.max_seq_length
100+
except Exception:
101+
pass
102+
return self._local_model
103+
74104
def encode(
75105
self,
76106
texts,
@@ -80,7 +110,6 @@ def encode(
80110
show_progress_bar: bool = False,
81111
**kwargs,
82112
):
83-
del show_progress_bar, kwargs
84113
if isinstance(texts, str):
85114
texts = [texts]
86115
if not isinstance(texts, list):
@@ -90,60 +119,78 @@ def encode(
90119
return empty if convert_to_numpy else empty.tolist()
91120

92121
safe_batch_size = max(int(batch_size or self.default_batch_size), 1)
93-
chunks = [texts[i : i + safe_batch_size] for i in range(0, len(texts), safe_batch_size)]
94-
outputs: list[np.ndarray] = []
95-
96-
self._log(
97-
f"[INFO] 远程 embedding:model={self.model_name} "
98-
f"endpoint={self.endpoint} total={len(texts)} batch={safe_batch_size}"
99-
)
122+
try:
123+
chunks = [texts[i : i + safe_batch_size] for i in range(0, len(texts), safe_batch_size)]
124+
outputs: list[np.ndarray] = []
100125

101-
for chunk_index, chunk in enumerate(chunks, start=1):
102-
headers = self._headers()
103-
response = requests.post(
104-
self.endpoint,
105-
headers=headers,
106-
json={"texts": chunk},
107-
timeout=self.timeout,
126+
self._log(
127+
f"[INFO] 远程 embedding:model={self.model_name} "
128+
f"endpoint={self.endpoint} total={len(texts)} batch={safe_batch_size}"
108129
)
109-
if response.status_code == 401 and headers.get("Authorization"):
110-
self._log("[WARN] 远程 embedding 鉴权失败,自动回退为无鉴权请求重试一次。")
111-
headers = {
112-
"Content-Type": "application/json",
113-
}
130+
131+
for chunk_index, chunk in enumerate(chunks, start=1):
132+
headers = self._headers()
114133
response = requests.post(
115134
self.endpoint,
116135
headers=headers,
117136
json={"texts": chunk},
118137
timeout=self.timeout,
119138
)
120-
response.raise_for_status()
121-
data = response.json()
122-
embeddings = data.get("embeddings")
123-
if not isinstance(embeddings, list):
124-
raise RuntimeError("远程 embedding 服务返回缺少 embeddings 字段")
125-
try:
126-
arr = np.asarray(embeddings, dtype=np.float32)
127-
except Exception as exc:
128-
raise RuntimeError(f"远程 embedding 返回无法转换为 float32:{exc}") from exc
129-
130-
if arr.ndim != 2:
131-
raise RuntimeError(f"远程 embedding 返回维度异常:shape={getattr(arr, 'shape', None)}")
132-
if arr.shape[0] != len(chunk):
133-
raise RuntimeError(
134-
f"远程 embedding 返回条数异常:expected={len(chunk)} actual={arr.shape[0]}"
139+
if response.status_code == 401 and headers.get("Authorization"):
140+
self._log("[WARN] 远程 embedding 鉴权失败,自动回退为无鉴权请求重试一次。")
141+
headers = {
142+
"Content-Type": "application/json",
143+
}
144+
response = requests.post(
145+
self.endpoint,
146+
headers=headers,
147+
json={"texts": chunk},
148+
timeout=self.timeout,
149+
)
150+
response.raise_for_status()
151+
data = response.json()
152+
embeddings = data.get("embeddings")
153+
if not isinstance(embeddings, list):
154+
raise RuntimeError("远程 embedding 服务返回缺少 embeddings 字段")
155+
try:
156+
arr = np.asarray(embeddings, dtype=np.float32)
157+
except Exception as exc:
158+
raise RuntimeError(f"远程 embedding 返回无法转换为 float32:{exc}") from exc
159+
160+
if arr.ndim != 2:
161+
raise RuntimeError(f"远程 embedding 返回维度异常:shape={getattr(arr, 'shape', None)}")
162+
if arr.shape[0] != len(chunk):
163+
raise RuntimeError(
164+
f"远程 embedding 返回条数异常:expected={len(chunk)} actual={arr.shape[0]}"
165+
)
166+
if normalize_embeddings:
167+
norms = np.linalg.norm(arr, axis=1, keepdims=True)
168+
arr = arr / np.clip(norms, 1e-12, None)
169+
outputs.append(arr)
170+
self._log(
171+
f"[INFO] 远程 embedding 批次完成:{chunk_index}/{len(chunks)} "
172+
f"count={len(chunk)} dim={arr.shape[1]}"
135173
)
136-
if normalize_embeddings:
137-
norms = np.linalg.norm(arr, axis=1, keepdims=True)
138-
arr = arr / np.clip(norms, 1e-12, None)
139-
outputs.append(arr)
140-
self._log(
141-
f"[INFO] 远程 embedding 批次完成:{chunk_index}/{len(chunks)} "
142-
f"count={len(chunk)} dim={arr.shape[1]}"
143-
)
144174

145-
merged = np.vstack(outputs) if outputs else np.zeros((0, 0), dtype=np.float32)
146-
return merged if convert_to_numpy else merged.tolist()
175+
merged = np.vstack(outputs) if outputs else np.zeros((0, 0), dtype=np.float32)
176+
return merged if convert_to_numpy else merged.tolist()
177+
except Exception as exc:
178+
self._log(f"[WARN] 远程 embedding 请求失败,将自动回退本地模型:{exc}")
179+
local_model = self._get_local_model()
180+
result = local_model.encode(
181+
texts,
182+
convert_to_numpy=convert_to_numpy,
183+
normalize_embeddings=normalize_embeddings,
184+
batch_size=safe_batch_size,
185+
show_progress_bar=show_progress_bar,
186+
**kwargs,
187+
)
188+
if convert_to_numpy and not isinstance(result, np.ndarray):
189+
try:
190+
result = np.asarray(result, dtype=np.float32)
191+
except Exception:
192+
pass
193+
return result
147194

148195
def start_multi_process_pool(self, target_devices=None):
149196
del target_devices
@@ -266,9 +313,32 @@ def load_sentence_transformer(
266313
endpoint=str(remote_endpoint).strip(),
267314
api_key=remote_api_key,
268315
timeout=remote_timeout,
316+
local_device=device,
317+
local_retries=retries,
318+
local_providers=providers,
269319
log=log,
270320
)
271321

322+
return _load_local_sentence_transformer(
323+
model_name,
324+
device=device,
325+
retries=retries,
326+
log=log,
327+
providers=providers,
328+
)
329+
330+
331+
def _load_local_sentence_transformer(
332+
model_name: str,
333+
*,
334+
device: str,
335+
retries: int | None = None,
336+
log: Callable[[str], None] = _log_default,
337+
providers: tuple[tuple[str, str], ...] = (
338+
("huggingface", HUGGINGFACE_ENDPOINT),
339+
("modelscope", MODELSCOPE_ENDPOINT),
340+
),
341+
):
272342
if retries is None:
273343
env_retries = os.getenv("LLM_EMBED_MODEL_RETRIES")
274344
if env_retries is None:

tests/test_model_loader_remote.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import MagicMock, patch
44

55
import numpy as np
6+
import requests
67

78
from src.model_loader import RemoteSentenceTransformer, load_sentence_transformer
89

@@ -51,6 +52,33 @@ def test_remote_encode_batches_and_normalizes(self, mock_post):
5152
self.assertEqual(first_call.kwargs["headers"]["Authorization"], "Bearer test-key")
5253
self.assertEqual(first_call.kwargs["timeout"], 30)
5354

55+
@patch("src.model_loader._load_local_sentence_transformer")
56+
@patch("src.model_loader.requests.post")
57+
def test_remote_encode_falls_back_to_local_model_when_remote_fails(self, mock_post, mock_load_local):
58+
mock_post.side_effect = requests.exceptions.Timeout("remote timeout")
59+
local_model = MagicMock()
60+
local_model.encode.return_value = np.asarray([[0.1, 0.2]], dtype=np.float32)
61+
mock_load_local.return_value = local_model
62+
63+
model = RemoteSentenceTransformer(
64+
model_name="BAAI/bge-small-en-v1.5",
65+
endpoint="https://embed.zwwen.online",
66+
api_key="test-key",
67+
timeout=30,
68+
default_batch_size=2,
69+
)
70+
arr = model.encode(
71+
["a"],
72+
convert_to_numpy=True,
73+
normalize_embeddings=True,
74+
batch_size=2,
75+
)
76+
77+
self.assertEqual(mock_post.call_count, 1)
78+
mock_load_local.assert_called_once()
79+
local_model.encode.assert_called_once()
80+
self.assertEqual(arr.shape, (1, 2))
81+
5482
@patch.dict(
5583
os.environ,
5684
{

0 commit comments

Comments (0)