Skip to content

Commit 573b973

Browse files
authored
add three embed method (#34)
* add auto embedding method selection * revise three method
1 parent d21b963 commit 573b973

2 files changed

Lines changed: 146 additions & 57 deletions

File tree

src/dataflex/offline_selector/offline_near_selector.py

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
import numpy as np
44
import faiss
55
import heapq
6+
# ===== auto optional embedding backends =====
67
try:
78
from vllm import LLM, SamplingParams
89
VLLM_AVAILABLE = True
9-
except ImportError:
10+
except Exception:
1011
VLLM_AVAILABLE = False
11-
from sentence_transformers import SentenceTransformer
12+
13+
try:
14+
from sentence_transformers import SentenceTransformer
15+
ST_AVAILABLE = True
16+
except Exception:
17+
ST_AVAILABLE = False
1218
from dataflex.utils.logging import logger
1319

1420
# ========== FAISS IVFFlat 索引封装类 ==========
@@ -36,13 +42,15 @@ def __init__(self,
3642
candidate_path = None,
3743
query_path: str = None,
3844
embed_model: str = "Qwen/Qwen3-Embedding-0.6B",
45+
embed_method: str ="auto",
3946
batch_size: int = 32,
4047
save_indices_path: str = "top_indices.npy",
4148
max_K: int = 1000):
42-
49+
4350
self.candidate_path = candidate_path
4451
self.query_path = query_path
4552
self.embed_model = embed_model
53+
self.embed_method = embed_method
4654
self.batch_size = batch_size
4755
self.save_indices_path = save_indices_path
4856
self.max_K = max_K
@@ -64,28 +72,62 @@ def _load_alpaca_json(self, path):
6472

6573
# ---------- Embedding 方法 ----------
6674
def _embed_texts(self, texts):
67-
if VLLM_AVAILABLE and self.embed_model.startswith("vllm:"):
68-
model_name = self.embed_model.replace("vllm:", "")
69-
logger.info(f"[EMBED] vLLM model: {model_name}")
70-
llm = LLM(model=model_name, trust_remote_code=True, task="embed")
71-
72-
# 使用 vLLM 的 embed 接口
73-
outputs = llm.embed(texts) # 返回 [N, D]
74-
print(f"Embeddings shape: {np.array(outputs).shape}", outputs[0])
75-
embs = [o.outputs.embedding for o in outputs]
76-
embs = np.array(embs, dtype=np.float32)
77-
else:
78-
logger.info(f"[EMBED] SentenceTransformer: {self.embed_model}")
79-
model = SentenceTransformer(self.embed_model)
80-
embs = model.encode(texts,
81-
batch_size=self.batch_size,
82-
show_progress_bar=True).astype(np.float32)
83-
norms = np.linalg.norm(embs, axis=1, keepdims=True) # [N, 1]
84-
# 防止除以 0
85-
norms = np.maximum(norms, 1e-12)
86-
embs = embs / norms
87-
# --------------------------------------
88-
return np.ascontiguousarray(embs)
75+
'''
76+
auto模式自动尝试 embedding 后端:
77+
1) 优先 vLLM
78+
2) 否则 sentence-transformers
79+
3) 都不可用则报错
80+
'''
81+
82+
# -------- 1. 优先 vLLM --------
83+
if (VLLM_AVAILABLE and self.embed_method == "auto") or self.embed_method == "vllm":
84+
try:
85+
logger.info(f"[EMBED] Using vLLM model: {self.embed_model}")
86+
llm = LLM(model=self.embed_model, trust_remote_code=True, task="embed")
87+
88+
outputs = llm.embed(texts) # [N, D]
89+
embs = [o.outputs.embedding for o in outputs]
90+
embs = np.array(embs, dtype=np.float32)
91+
92+
# normalize
93+
norms = np.linalg.norm(embs, axis=1, keepdims=True)
94+
norms = np.maximum(norms, 1e-12)
95+
embs = embs / norms
96+
97+
return np.ascontiguousarray(embs)
98+
99+
except Exception as e:
100+
logger.warning(f"[EMBED] vLLM available but embedding failed {e}")
101+
102+
# -------- 2. fallback: sentence-transformers --------
103+
if (ST_AVAILABLE and self.embed_method == "auto") or self.embed_method == "sentence-transformer":
104+
try:
105+
logger.info(f"[EMBED] Using SentenceTransformer: {self.embed_model}")
106+
model = SentenceTransformer(self.embed_model)
107+
embs = model.encode(
108+
texts,
109+
batch_size=self.batch_size,
110+
show_progress_bar=True
111+
).astype(np.float32)
112+
113+
norms = np.linalg.norm(embs, axis=1, keepdims=True)
114+
norms = np.maximum(norms, 1e-12)
115+
embs = embs / norms
116+
117+
return np.ascontiguousarray(embs)
118+
119+
except Exception as e:
120+
raise RuntimeError(
121+
f"SentenceTransformer available but embedding failed: {e}"
122+
)
123+
124+
# -------- 3. 两个都不可用 --------
125+
raise RuntimeError(
126+
"No available embedding backend!\n"
127+
"Please install at least one of the following:\n"
128+
" - vLLM: pip install vllm\n"
129+
" - sentence-transformers: pip install sentence-transformers"
130+
)
89131

90132
# ---------- 调用接口 ----------
91133
def candidate_sentence_embedding(self):
@@ -138,10 +180,13 @@ def selector(self):
138180
near = offline_near_Selector(
139181
candidate_path="OpenDCAI/DataFlex-selector-openhermes-10w", # split = train
140182
query_path="OpenDCAI/DataFlex-selector-openhermes-10w", # split = vaildation
141-
142-
# If you want to use vllm,please add "vllm:" before model's name
143-
# Otherwise it automatically use sentence-transfromer
144-
embed_model="vllm:Qwen/Qwen3-Embedding-0.6B",
183+
# It automatically try vllm first, then sentence-transformers
184+
embed_model="Qwen/Qwen3-Embedding-0.6B",
185+
# support method:
186+
#auto(It automatically try vllm first, then sentence-transformers),
187+
#vllm,
188+
#sentence-transformer
189+
embed_method= "auto",
145190
batch_size=32,
146191
save_indices_path="top_indices.npy",
147192
max_K=1000,

src/dataflex/offline_selector/offline_tsds_selector.py

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
import numpy as np
44
import faiss
55
import heapq
6+
# ===== auto optional embedding backends =====
67
try:
78
from vllm import LLM, SamplingParams
89
VLLM_AVAILABLE = True
9-
except ImportError:
10+
except Exception:
1011
VLLM_AVAILABLE = False
11-
from sentence_transformers import SentenceTransformer
12+
13+
try:
14+
from sentence_transformers import SentenceTransformer
15+
ST_AVAILABLE = True
16+
except Exception:
17+
ST_AVAILABLE = False
1218
from dataflex.utils.logging import logger
1319

1420
# ========== FAISS IVFFlat 索引封装类 ==========
@@ -36,6 +42,7 @@ def __init__(self,
3642
candidate_path = None,
3743
query_path: str = None,
3844
embed_model: str = "Qwen/Qwen3-Embedding-0.6B",
45+
embed_method: str ="auto",
3946
batch_size: int = 32,
4047
save_probs_path: str = "tsds_probs.npy",
4148
max_K: int = 5000,
@@ -47,6 +54,7 @@ def __init__(self,
4754
self.candidate_path = candidate_path
4855
self.query_path = query_path
4956
self.embed_model = embed_model
57+
self.embed_method = embed_method
5058
self.batch_size = batch_size
5159
self.save_probs_path = save_probs_path
5260
self.max_K = max_K
@@ -72,28 +80,62 @@ def _load_alpaca_json(self, path):
7280

7381
# ---------- Embedding 方法 ----------
7482
def _embed_texts(self, texts):
75-
if VLLM_AVAILABLE and self.embed_model.startswith("vllm:"):
76-
model_name = self.embed_model.replace("vllm:", "")
77-
logger.info(f"[EMBED] vLLM model: {model_name}")
78-
llm = LLM(model=model_name, trust_remote_code=True, task="embed")
79-
80-
# 使用 vLLM 的 embed 接口
81-
outputs = llm.embed(texts) # 返回 [N, D]
82-
print(f"Embeddings shape: {np.array(outputs).shape}", outputs[0])
83-
embs = [o.outputs.embedding for o in outputs]
84-
embs = np.array(embs, dtype=np.float32)
85-
else:
86-
logger.info(f"[EMBED] SentenceTransformer: {self.embed_model}")
87-
model = SentenceTransformer(self.embed_model)
88-
embs = model.encode(texts,
89-
batch_size=self.batch_size,
90-
show_progress_bar=True).astype(np.float32)
91-
norms = np.linalg.norm(embs, axis=1, keepdims=True) # [N, 1]
92-
# 防止除以 0
93-
norms = np.maximum(norms, 1e-12)
94-
embs = embs / norms
95-
# --------------------------------------
96-
return np.ascontiguousarray(embs)
83+
'''
84+
auto模式自动尝试 embedding 后端:
85+
1) 优先 vLLM
86+
2) 否则 sentence-transformers
87+
3) 都不可用则报错
88+
'''
89+
90+
# -------- 1. 优先 vLLM --------
91+
if (VLLM_AVAILABLE and self.embed_method == "auto") or self.embed_method == "vllm":
92+
try:
93+
logger.info(f"[EMBED] Using vLLM model: {self.embed_model}")
94+
llm = LLM(model=self.embed_model, trust_remote_code=True, task="embed")
95+
96+
outputs = llm.embed(texts) # [N, D]
97+
embs = [o.outputs.embedding for o in outputs]
98+
embs = np.array(embs, dtype=np.float32)
99+
100+
# normalize
101+
norms = np.linalg.norm(embs, axis=1, keepdims=True)
102+
norms = np.maximum(norms, 1e-12)
103+
embs = embs / norms
104+
105+
return np.ascontiguousarray(embs)
106+
107+
except Exception as e:
108+
logger.warning(f"[EMBED] vLLM available but embedding failed {e}")
109+
110+
# -------- 2. fallback: sentence-transformers --------
111+
if (ST_AVAILABLE and self.embed_method == "auto") or self.embed_method == "sentence-transformer":
112+
try:
113+
logger.info(f"[EMBED] Using SentenceTransformer: {self.embed_model}")
114+
model = SentenceTransformer(self.embed_model)
115+
embs = model.encode(
116+
texts,
117+
batch_size=self.batch_size,
118+
show_progress_bar=True
119+
).astype(np.float32)
120+
121+
norms = np.linalg.norm(embs, axis=1, keepdims=True)
122+
norms = np.maximum(norms, 1e-12)
123+
embs = embs / norms
124+
125+
return np.ascontiguousarray(embs)
126+
127+
except Exception as e:
128+
raise RuntimeError(
129+
f"SentenceTransformer available but embedding failed: {e}"
130+
)
131+
132+
# -------- 3. 两个都不可用 --------
133+
raise RuntimeError(
134+
"No available embedding backend!\n"
135+
"Please install at least one of the following:\n"
136+
" - vLLM: pip install vllm\n"
137+
" - sentence-transformers: pip install sentence-transformers"
138+
)
97139

98140
# ---------- TSDS 调用接口 ----------
99141
def candidate_sentence_embedding(self):
@@ -192,10 +234,12 @@ def selector(self):
192234
tsds = offline_tsds_Selector(
193235
candidate_path="OpenDCAI/DataFlex-selector-openhermes-10w",
194236
query_path="OpenDCAI/DataFlex-selector-openhermes-10w",
195-
196-
# If you want to use vllm,please add "vllm:" before model's name
197-
# Otherwise it automatically use sentence-transfromer
198-
embed_model="vllm:Qwen/Qwen3-Embedding-0.6B",
237+
embed_model="Qwen/Qwen3-Embedding-0.6B",
238+
# support method:
239+
#auto(It automatically try vllm first, then sentence-transformers),
240+
#vllm,
241+
#sentence-transformer
242+
embed_method="auto",
199243
batch_size=32,
200244
save_probs_path="tsds_probs.npy",
201245
max_K=5000,

0 commit comments

Comments
 (0)