Skip to content

Commit 768c6f4

Browse files
committed
[beautify] 优化远程embedding模式日志与依赖安装
1 parent ce58f28 commit 768c6f4

3 files changed

Lines changed: 45 additions & 20 deletions

File tree

.github/workflows/daily-paper-reader.yml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,23 @@ jobs:
5555
with:
5656
python-version: "3.11"
5757

58-
- name: Cache pip + torch
58+
- name: Cache pip
5959
uses: actions/cache@v4
6060
with:
61-
path: |
62-
~/.cache/pip
63-
~/.cache/torch
64-
key: ${{ runner.os }}-dpr-hf-v2-${{ hashFiles('requirements.txt') }}
61+
path: ~/.cache/pip
62+
key: ${{ runner.os }}-dpr-remote-pip-v1-${{ hashFiles('requirements.txt') }}
6563

6664
- name: Install deps (skip sqlite3)
6765
run: |
6866
python - <<'PY'
6967
import re
7068
lines = open("requirements.txt", "r", encoding="utf-8").read().splitlines()
71-
lines = [l for l in lines if l.strip() and not re.match(r"^sqlite3\\b", l)]
69+
lines = [
70+
l for l in lines
71+
if l.strip()
72+
and not re.match(r"^sqlite3\\b", l)
73+
and not re.match(r"^sentence-transformers\\b", l)
74+
]
7275
open("/tmp/req.txt", "w", encoding="utf-8").write("\n".join(lines))
7376
PY
7477
python -m pip install --upgrade pip

src/filter.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#!/usr/bin/env python
22
# 通用向量检索工具:封装 sentence-transformers 的向量计算与粗筛逻辑
33

4+
from __future__ import annotations
5+
46
import os
57
import numpy as np
6-
from typing import Any, Dict, List
8+
from typing import Any, Dict, List, TYPE_CHECKING
79
import time
810
from datetime import datetime, timezone
911

1012
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
1113

12-
import torch
13-
from sentence_transformers import SentenceTransformer
14+
from model_loader import is_remote_embedding_enabled, load_sentence_transformer
1415

15-
from model_loader import load_sentence_transformer
16+
if TYPE_CHECKING:
17+
from sentence_transformers import SentenceTransformer
1618

1719
# E5 系列推荐使用 query/passsage 前缀来区分检索侧与文档侧
1820
E5_QUERY_PREFIX = "query: "
@@ -30,6 +32,8 @@ def debug_hf_runtime(prefix: str) -> None:
3032
enable = (os.getenv("DPR_DEBUG_HF") == "1") or (os.getenv("GITHUB_ACTIONS") == "true")
3133
if not enable:
3234
return
35+
if is_remote_embedding_enabled():
36+
return
3337

3438
log(f"[DEBUG][HF] {prefix}")
3539
keys = [
@@ -73,7 +77,7 @@ def ls_dir(path: str) -> None:
7377
ls_dir(hf_home)
7478

7579

76-
def _set_max_seq_length(model: SentenceTransformer, max_length: int | None) -> None:
80+
def _set_max_seq_length(model: Any, max_length: int | None) -> None:
7781
"""尽量通过 SentenceTransformer 的 max_seq_length 控制截断长度。"""
7882
if max_length is None or max_length <= 0:
7983
return
@@ -93,7 +97,7 @@ def _set_max_seq_length(model: SentenceTransformer, max_length: int | None) -> N
9397

9498

9599
def encode_queries(
96-
model: SentenceTransformer,
100+
model: Any,
97101
texts: List[str],
98102
batch_size: int = 8,
99103
max_length: int | None = None,
@@ -128,7 +132,7 @@ def encode_queries(
128132

129133

130134
def compute_embeddings(
131-
model: SentenceTransformer,
135+
model: Any,
132136
items: List[Any],
133137
batch_size: int = 8,
134138
max_length: int | None = None,
@@ -206,15 +210,27 @@ def __init__(
206210
self.batch_size = batch_size
207211
self.max_length = max_length
208212

213+
remote_mode = is_remote_embedding_enabled()
209214
if device is None:
210-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
215+
if remote_mode:
216+
self.device = "remote"
217+
else:
218+
try:
219+
import torch
220+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
221+
except Exception:
222+
self.device = "cpu"
211223
else:
212-
self.device = device
224+
self.device = device if not remote_mode else "remote"
213225

214-
print(f"[INFO] 正在加载向量模型:{self.model_name},device={self.device}")
215-
debug_hf_runtime("before SentenceTransformer()")
226+
if remote_mode:
227+
print(f"[INFO] 正在初始化远程向量服务:{self.model_name},device={self.device}")
228+
else:
229+
print(f"[INFO] 正在加载本地向量模型:{self.model_name},device={self.device}")
230+
debug_hf_runtime("before SentenceTransformer()")
216231
self.model = load_sentence_transformer(self.model_name, device=self.device)
217-
debug_hf_runtime("after SentenceTransformer()")
232+
if not remote_mode:
233+
debug_hf_runtime("after SentenceTransformer()")
218234
_set_max_seq_length(self.model, self.max_length)
219235

220236
def filter(self, items: List[Any], queries: List[Dict[str, Any]]) -> Dict[str, Any]:

src/model_loader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from contextlib import contextmanager
66
import os
77
import time
8-
from typing import Callable, Optional
8+
from typing import Callable, Optional, TYPE_CHECKING
99

1010
import numpy as np
1111
import requests
1212

13-
from sentence_transformers import SentenceTransformer
13+
if TYPE_CHECKING:
14+
from sentence_transformers import SentenceTransformer
1415

1516

1617
HUGGINGFACE_ENDPOINT = "https://huggingface.co"
@@ -27,6 +28,10 @@ def _log_default(message: str) -> None:
2728
print(message, flush=True)
2829

2930

31+
def is_remote_embedding_enabled() -> bool:
32+
return bool(str(_DEFAULT_REMOTE_EMBED_ENDPOINT or "").strip())
33+
34+
3035
class RemoteSentenceTransformer:
3136
"""兼容 SentenceTransformer.encode 接口的远程 embedding 包装器。"""
3237

@@ -299,6 +304,7 @@ def load_sentence_transformer(
299304
f"(provider={provider_name},device={device})"
300305
)
301306
with _hf_endpoint(endpoint), _hf_http_backoff(max_retries=hf_backoff_retries):
307+
from sentence_transformers import SentenceTransformer
302308
return SentenceTransformer(model_name, device=device)
303309
except Exception as e: # pragma: no cover - 仅异常路径
304310
last_err = e

0 commit comments

Comments
 (0)