33import numpy as np
44import faiss
55import heapq
6+ # ===== auto optional embedding backends =====
67try :
78 from vllm import LLM , SamplingParams
89 VLLM_AVAILABLE = True
9- except ImportError :
10+ except Exception :
1011 VLLM_AVAILABLE = False
11- from sentence_transformers import SentenceTransformer
12+
13+ try :
14+ from sentence_transformers import SentenceTransformer
15+ ST_AVAILABLE = True
16+ except Exception :
17+ ST_AVAILABLE = False
1218from dataflex .utils .logging import logger
1319
1420# ========== FAISS IVFFlat 索引封装类 ==========
@@ -36,6 +42,7 @@ def __init__(self,
3642 candidate_path = None ,
3743 query_path : str = None ,
3844 embed_model : str = "Qwen/Qwen3-Embedding-0.6B" ,
45+ embed_method : str = "auto" ,
3946 batch_size : int = 32 ,
4047 save_probs_path : str = "tsds_probs.npy" ,
4148 max_K : int = 5000 ,
@@ -47,6 +54,7 @@ def __init__(self,
4754 self .candidate_path = candidate_path
4855 self .query_path = query_path
4956 self .embed_model = embed_model
57+ self .embed_method = embed_method
5058 self .batch_size = batch_size
5159 self .save_probs_path = save_probs_path
5260 self .max_K = max_K
@@ -72,28 +80,62 @@ def _load_alpaca_json(self, path):
7280
7381 # ---------- Embedding 方法 ----------
7482 def _embed_texts (self , texts ):
75- if VLLM_AVAILABLE and self .embed_model .startswith ("vllm:" ):
76- model_name = self .embed_model .replace ("vllm:" , "" )
77- logger .info (f"[EMBED] vLLM model: { model_name } " )
78- llm = LLM (model = model_name , trust_remote_code = True , task = "embed" )
79-
80- # 使用 vLLM 的 embed 接口
81- outputs = llm .embed (texts ) # 返回 [N, D]
82- print (f"Embeddings shape: { np .array (outputs ).shape } " , outputs [0 ])
83- embs = [o .outputs .embedding for o in outputs ]
84- embs = np .array (embs , dtype = np .float32 )
85- else :
86- logger .info (f"[EMBED] SentenceTransformer: { self .embed_model } " )
87- model = SentenceTransformer (self .embed_model )
88- embs = model .encode (texts ,
89- batch_size = self .batch_size ,
90- show_progress_bar = True ).astype (np .float32 )
91- norms = np .linalg .norm (embs , axis = 1 , keepdims = True ) # [N, 1]
92- # 防止除以 0
93- norms = np .maximum (norms , 1e-12 )
94- embs = embs / norms
95- # --------------------------------------
96- return np .ascontiguousarray (embs )
83+ '''
84+ auto模式自动尝试 embedding 后端:
85+ 1) 优先 vLLM
86+ 2) 否则 sentence-transformers
87+ 3) 都不可用则报错
88+ '''
89+
90+ # -------- 1. 优先 vLLM --------
91+ if (VLLM_AVAILABLE and self .embed_method == "auto" ) or self .embed_method == "vllm" :
92+ try :
93+ logger .info (f"[EMBED] Using vLLM model: { self .embed_model } " )
94+ llm = LLM (model = self .embed_model , trust_remote_code = True , task = "embed" )
95+
96+ outputs = llm .embed (texts ) # [N, D]
97+ embs = [o .outputs .embedding for o in outputs ]
98+ embs = np .array (embs , dtype = np .float32 )
99+
100+ # normalize
101+ norms = np .linalg .norm (embs , axis = 1 , keepdims = True )
102+ norms = np .maximum (norms , 1e-12 )
103+ embs = embs / norms
104+
105+ return np .ascontiguousarray (embs )
106+
107+ except Exception as e :
108+ logger .warning (f"[EMBED] vLLM available but embedding failed { e } " )
109+
110+ # -------- 2. fallback: sentence-transformers --------
111+ if (ST_AVAILABLE and self .embed_method == "auto" ) or self .embed_method == "sentence-transformer" :
112+ try :
113+ logger .info (f"[EMBED] Using SentenceTransformer: { self .embed_model } " )
114+ model = SentenceTransformer (self .embed_model )
115+ embs = model .encode (
116+ texts ,
117+ batch_size = self .batch_size ,
118+ show_progress_bar = True
119+ ).astype (np .float32 )
120+
121+ norms = np .linalg .norm (embs , axis = 1 , keepdims = True )
122+ norms = np .maximum (norms , 1e-12 )
123+ embs = embs / norms
124+
125+ return np .ascontiguousarray (embs )
126+
127+ except Exception as e :
128+ raise RuntimeError (
129+ f"SentenceTransformer available but embedding failed: { e } "
130+ )
131+
132+ # -------- 3. 两个都不可用 --------
133+ raise RuntimeError (
134+ "No available embedding backend!\n "
135+ "Please install at least one of the following:\n "
136+ " - vLLM: pip install vllm\n "
137+ " - sentence-transformers: pip install sentence-transformers"
138+ )
97139
98140 # ---------- TSDS 调用接口 ----------
99141 def candidate_sentence_embedding (self ):
@@ -192,10 +234,12 @@ def selector(self):
192234 tsds = offline_tsds_Selector (
193235 candidate_path = "OpenDCAI/DataFlex-selector-openhermes-10w" ,
194236 query_path = "OpenDCAI/DataFlex-selector-openhermes-10w" ,
195-
196- # If you want to use vllm,please add "vllm:" before model's name
197- # Otherwise it automatically use sentence-transfromer
198- embed_model = "vllm:Qwen/Qwen3-Embedding-0.6B" ,
237+ embed_model = "Qwen/Qwen3-Embedding-0.6B" ,
238+ # support method:
239+ #auto(It automatically try vllm first, then sentence-transformers),
240+ #vllm,
241+ #sentence-transformer
242+ embed_method = "auto" ,
199243 batch_size = 32 ,
200244 save_probs_path = "tsds_probs.npy" ,
201245 max_K = 5000 ,
0 commit comments