-
Notifications
You must be signed in to change notification settings - Fork 364
support qwen3-reranker #732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lvjg
wants to merge
3
commits into
LazyAGI:main
Choose a base branch
from
lvjg:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -1,7 +1,9 @@ | ||||
| import importlib.util | ||||
|
|
||||
| import re | ||||
| import requests | ||||
| from functools import lru_cache | ||||
| from typing import Callable, List, Optional, Union | ||||
| from typing import Callable, List, Dict, Optional, Union, Any | ||||
|
|
||||
| import lazyllm | ||||
| from lazyllm.thirdparty import spacy | ||||
|
|
@@ -91,6 +93,7 @@ def KeywordFilter(node: DocNode, required_keys: Optional[List[str]] = None, excl | |||
| return None | ||||
| return node | ||||
|
|
||||
|
|
||||
| @Reranker.register_reranker() | ||||
| class ModuleReranker(Reranker): | ||||
|
|
||||
|
|
@@ -116,6 +119,303 @@ def forward(self, nodes: List[DocNode], query: str = "") -> List[DocNode]: | |||
| LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}") | ||||
| return self._post_process(results) | ||||
|
|
||||
|
|
||||
| @Reranker.register_reranker() | ||||
| class UrlReranker(Reranker): | ||||
| """ | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 文档注释需要写到这里:lazyllm/docs/module.py |
||||
| 通用 HTTP 重排序器。 | ||||
|
|
||||
| 通过将 query 与一批候选文本打包为 JSON 请求发送到远端 URL, | ||||
| 解析返回的分数后对节点进行重排。 | ||||
|
|
||||
| 远端服务期望的响应格式(默认): | ||||
| List[{"index": int, "score": float}] | ||||
| 其中 "index" 为该批次内文档的局部索引(从 0 开始),"score" 为相关性分数。 | ||||
| """ | ||||
|
|
||||
| def __init__( | ||||
| self, | ||||
| name: str = "UrlReranker", | ||||
| url: Optional[str] = None, | ||||
| api_key: str = "api_key", | ||||
| batch_size: int = 64, | ||||
| truncate_text: bool = True, | ||||
| output_format: Optional[str] = None, | ||||
| join: Union[bool, str] = False, | ||||
| timeout: Optional[float] = None, | ||||
| **kwargs: Any, | ||||
| ) -> None: | ||||
| """ | ||||
| Args: | ||||
| name: 重排序器名称。 | ||||
| url: 远端重排序服务地址(必填)。 | ||||
| api_key: 认证密钥(将置于 HTTP Bearer 头)。 | ||||
| batch_size: 批大小(原 rerank_batch_size)。 | ||||
| truncate_text: 是否在远端对文本进行截断。 | ||||
| output_format, join, **kwargs: 继承自 Reranker 的可选参数。 | ||||
| request_timeout: 请求超时时间,缺省为 DEFAULT_TIMEOUT。 | ||||
| """ | ||||
| super().__init__(name=name, output_format=output_format, join=join, **kwargs) | ||||
| if not url: | ||||
| raise ValueError("`url` 不能为空,请传入远端重排序服务地址。") | ||||
|
|
||||
| self._url = url | ||||
| self._api_key = api_key | ||||
| self._batch_size = max(1, int(batch_size)) | ||||
| self._truncate_text = bool(truncate_text) | ||||
| self._timeout = timeout | ||||
|
|
||||
| self._headers: Dict[str, str] = self._build_headers() | ||||
| self._session = requests.Session() | ||||
|
|
||||
| def _build_headers(self) -> Dict[str, str]: | ||||
| """构建 HTTP 头。""" | ||||
| return { | ||||
| "Content-Type": "application/json", | ||||
| "Authorization": f"Bearer {self._api_key}", | ||||
| } | ||||
|
|
||||
| def _extract_top_k(self, total: int, **kwargs: Any) -> int: | ||||
| """从 kwargs 中解析 top_k/topk,默认取全部。""" | ||||
| top_k = kwargs.get("top_k", kwargs.get("topk", total)) | ||||
| try: | ||||
| top_k = int(top_k) | ||||
| except Exception: | ||||
| top_k = total | ||||
| return max(0, min(top_k, total)) | ||||
|
|
||||
| def _get_format_content(self, nodes: List[DocNode], **kwargs: Any) -> List[str]: | ||||
| """ | ||||
| 生成待重排的文本列表。 | ||||
|
|
||||
| 若提供 template(如: "标题:{title}\n正文:{text}"),将按节点 metadata 与 text 替换。 | ||||
| 支持的占位符来源: | ||||
| - {text}: 节点正文 | ||||
| - {<metadata_key>}: 节点 metadata 中的键 | ||||
| 若占位符缺失对应值,则回退为空串。 | ||||
| """ | ||||
| template: Optional[str] = dict(kwargs).pop("template", None) | ||||
| if not template: | ||||
| return [n.get_text(metadata_mode=MetadataMode.EMBED) for n in nodes] | ||||
|
|
||||
| placeholders = re.findall(r"{(\w+)}", template) | ||||
|
|
||||
| formatted: List[str] = [] | ||||
| for node in nodes: | ||||
| values = { | ||||
| key: node.text if key == 'text' else node.metadata.get(key, "") for key in placeholders | ||||
| } | ||||
| try: | ||||
| formatted.append(template.format(**values)) | ||||
| except Exception as exc: | ||||
| LOG.warning("Template formatting failed; fallback to raw text: %s", exc) | ||||
| formatted.append(node.get_text(metadata_mode=MetadataMode.EMBED)) | ||||
| return formatted | ||||
|
|
||||
| def _encapsulated_data(self, query: str, texts: List[str], **kwargs: Any) -> Dict[str, Any]: | ||||
| """ | ||||
| 封装请求体。子类可重写。 | ||||
| 默认字段: | ||||
| { | ||||
| "query": "<用户查询>", | ||||
| "texts": ["doc1", "doc2", ...], | ||||
| "truncate": bool | ||||
| } | ||||
| """ | ||||
| payload: Dict[str, Any] = { | ||||
| "query": query, | ||||
| "texts": list(texts), | ||||
| "truncate": self._truncate_text, | ||||
| } | ||||
| if kwargs: | ||||
| for k, v in kwargs.items(): | ||||
| if k not in ("query", "texts", "truncate"): | ||||
| payload[k] = v | ||||
| return payload | ||||
|
|
||||
| def _parse_response(self, response: Any) -> List[float]: | ||||
| """ | ||||
| 解析远端返回为分数列表。子类可重写。 | ||||
|
|
||||
| 期望输入:List[{"index": int, "score": float}] | ||||
| 输出顺序:按 "index" 排序返回分数列表。 | ||||
| """ | ||||
| if not isinstance(response, list): | ||||
| LOG.warning("Response is not a list; attempting lenient parsing: %r", response) | ||||
| return [] | ||||
|
|
||||
| try: | ||||
| sorted_data = sorted(response, key=lambda x: x["index"]) | ||||
| return [float(item["score"]) for item in sorted_data] | ||||
| except Exception as exc: | ||||
| LOG.error("Failed to parse response: %s; response=%r", exc, response) | ||||
| return [] | ||||
|
|
||||
| def forward(self, nodes: List[DocNode], query: str, **kwargs: Any) -> List[DocNode]: | ||||
| """ | ||||
| 对候选节点进行重排并返回 Top-K(若未指定则返回全部)。 | ||||
| """ | ||||
| if not nodes: | ||||
| return [] | ||||
|
|
||||
| texts = self._get_format_content(nodes, **kwargs) | ||||
| top_k = self._extract_top_k(len(texts), **kwargs) | ||||
|
|
||||
| all_scores: List[float] = [] | ||||
| for start in range(0, len(texts), self._batch_size): | ||||
| batch_texts = texts[start : start + self._batch_size] | ||||
| payload = self._encapsulated_data(query, batch_texts, **kwargs) | ||||
|
|
||||
| try: | ||||
| resp = self._session.post( | ||||
| self._url, json=payload, headers=self._headers, timeout=self._timeout | ||||
| ) | ||||
| resp.raise_for_status() | ||||
| scores = self._parse_response(resp.json()) | ||||
| except requests.RequestException as exc: | ||||
| LOG.error("HTTP request for reranking failed (this batch will be scored as 0): %s", exc) | ||||
| scores = [] | ||||
|
|
||||
| if len(scores) != len(batch_texts): | ||||
| LOG.warning( | ||||
| "Returned scores count mismatches inputs: got=%d, expected=%d; padding with zeros.", | ||||
| len(scores), len(batch_texts) | ||||
| ) | ||||
| if len(scores) < len(batch_texts): | ||||
| scores += [0.0] * (len(batch_texts) - len(scores)) | ||||
| else: | ||||
| scores = scores[: len(batch_texts)] | ||||
|
|
||||
| all_scores.extend(scores) | ||||
|
|
||||
| scored_nodes: List[DocNode] = [ | ||||
| nodes[i].with_score(all_scores[i]) for i in range(len(nodes)) | ||||
| ] | ||||
|
|
||||
| scored_nodes.sort(key=lambda n: n.relevance_score, reverse=True) | ||||
| results = scored_nodes[:top_k] if top_k > 0 else scored_nodes | ||||
| LOG.debug(f"Rerank use `{self._name}` and get nodes: {results}") | ||||
| return self._post_process(results) | ||||
|
|
||||
|
|
||||
| @Reranker.register_reranker() | ||||
| class Qwen3Reranker(UrlReranker): | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议参考这个来做:
|
||||
| """ | ||||
| 基于 Qwen3 样式 Prompt/响应协议的重排序器。 | ||||
| 请求体: | ||||
| { | ||||
| "query": "<拼装后的系统指令+用户查询>", | ||||
| "documents": ["<每个 doc 的拼装文本>", ...], | ||||
| ... # 其他可选字段 | ||||
| } | ||||
| 响应体(期望): | ||||
| { | ||||
| "results": [ | ||||
| {"index": int, "relevance_score": float}, | ||||
| ... | ||||
| ] | ||||
| } | ||||
| """ | ||||
|
|
||||
| _PROMPT_PREFIX = ( | ||||
| "<|im_start|>system\n" | ||||
| "Judge whether the Document meets the requirements based on the Query and the Instruct provided. " | ||||
| 'Note that the answer can only be "yes" or "no".' | ||||
| "<|im_end|>\n<|im_start|>user\n" | ||||
| ) | ||||
| _PROMPT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" | ||||
|
|
||||
| _QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" | ||||
| _DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" | ||||
|
|
||||
| _LOCAL_TRUNCATE_MAX_CHARS = 16384 | ||||
| _DEFAULT_TASK_DESCRIPTION = "Given a web search query, retrieve relevant passages that answer the query" | ||||
|
|
||||
| def __init__( | ||||
| self, | ||||
| name: str = "Qwen3Reranker", | ||||
| url: Optional[str] = None, | ||||
| api_key: str = "api_key", | ||||
| batch_size: int = 64, | ||||
| truncate_text: bool = True, | ||||
| output_format: Optional[str] = None, | ||||
| join: Union[bool, str] = False, | ||||
| task_description: Optional[str] = None, | ||||
| request_timeout: Optional[float] = None, | ||||
| **kwargs: Any, | ||||
| ) -> None: | ||||
| """ | ||||
| Args: | ||||
| task_description: 任务描述,会被拼入 system/user 区块。 | ||||
| """ | ||||
| super().__init__( | ||||
| name=name, | ||||
| url=url, | ||||
| api_key=api_key, | ||||
| batch_size=batch_size, | ||||
| truncate_text=truncate_text, | ||||
| output_format=output_format, | ||||
| join=join, | ||||
| request_timeout=request_timeout, | ||||
| **kwargs, | ||||
| ) | ||||
| self._task_description = task_description or self._DEFAULT_TASK_DESCRIPTION | ||||
|
|
||||
| def _build_instruct(self, task_description: str, query: str) -> str: | ||||
| """拼装包含系统前缀与用户区块的 query 字符串。""" | ||||
| return self._QUERY_TEMPLATE.format( | ||||
| prefix=self._PROMPT_PREFIX, instruction=task_description, query=query | ||||
| ) | ||||
|
|
||||
| def _build_documents(self, texts: List[str]) -> List[str]: | ||||
| """ | ||||
| 将每条文本套入文档模板;若开启 truncate,则在这里进行**本地字符级截断**。 | ||||
| - 截断阈值:_LOCAL_TRUNCATE_MAX_CHARS | ||||
| - 仅当 self._truncate_text 为 True 时生效 | ||||
| """ | ||||
| docs: List[str] = [] | ||||
|
|
||||
| def _truncate_if_needed(s: str) -> str: | ||||
| if not self._truncate_text: | ||||
| return s | ||||
| if len(s) <= self._LOCAL_TRUNCATE_MAX_CHARS: | ||||
| return s | ||||
| return s[: self._LOCAL_TRUNCATE_MAX_CHARS] | ||||
|
|
||||
| for t in texts: | ||||
| t_norm = _truncate_if_needed(t or "") | ||||
| docs.append(self._DOCUMENT_TEMPLATE.format(doc=t_norm, suffix=self._PROMPT_SUFFIX)) | ||||
| return docs | ||||
|
|
||||
| def _encapsulated_data(self, query: str, texts: List[str], **kwargs: Any) -> Dict[str, Any]: | ||||
| payload: Dict[str, Any] = { | ||||
| "query": self._build_instruct(self._task_description, query), | ||||
| "documents": self._build_documents(texts), | ||||
| } | ||||
| if kwargs: | ||||
| for k, v in kwargs.items(): | ||||
| if k not in ("query", "documents"): | ||||
| payload[k] = v | ||||
| return payload | ||||
|
|
||||
| def _parse_response(self, response: Any) -> List[float]: | ||||
| """ | ||||
| 期望输入: | ||||
| {"results": [{"index": int, "relevance_score": float}, ...]} | ||||
| """ | ||||
| if not isinstance(response, dict) or "results" not in response: | ||||
| LOG.warning("response missing 'results' field: %r", response) | ||||
| return [] | ||||
|
|
||||
| results = response.get("results", []) | ||||
| try: | ||||
| results = sorted(results, key=lambda x: x["index"]) | ||||
| return [float(item["relevance_score"]) for item in results] | ||||
| except Exception as exc: | ||||
| LOG.error("Failed to parse response: %s; response=%r", exc, response) | ||||
| return [] | ||||
|
|
||||
|
|
||||
| # User-defined similarity decorator | ||||
| def register_reranker(func=None, batch=False): | ||||
| return Reranker.register_reranker(func, batch) | ||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
都迁移到这里吧:
lazyllm/module/llms/onlinemodule/supplier,作为Online的,不要作为local的。