Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pymilvus/model/reranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pymilvus.model.reranker.cross_encoder import CrossEncoderRerankFunction
from pymilvus.model.reranker.jinaai import JinaRerankFunction
from pymilvus.model.reranker.tei import TEIRerankFunction
from pymilvus.model.reranker.dashscope import DashscopeRerankFunction

__all__ = [
"CohereRerankFunction",
Expand All @@ -12,4 +13,5 @@
"CrossEncoderRerankFunction",
"JinaRerankFunction",
"TEIRerankFunction",
"DashscopeRerankFunction",
]
62 changes: 62 additions & 0 deletions src/pymilvus/model/reranker/dashscope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
from typing import List, Optional

import requests

from pymilvus.model.base import BaseRerankFunction, RerankResult

API_URL = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"


class DashscopeRerankFunction(BaseRerankFunction):
def __init__(self, model_name: str = "qwen3-rerank", api_key: Optional[str] = None, **kwargs):
if api_key is None:
if "DASHSCOPE_API_KEY" in os.environ and os.environ["DASHSCOPE_API_KEY"]:
self.api_key = os.environ["DASHSCOPE_API_KEY"]
else:
error_message = (
"Did not find api_key, please add an environment variable"
" `DASHSCOPE_API_KEY` which contains it, or pass"
" `api_key` as a named parameter."
)
raise ValueError(error_message)
else:
self.api_key = api_key
self.model_name = model_name
self._session = requests.Session()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
self.model_name = model_name
self.rerank_config = {**kwargs}

def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
json_data = {
"model": self.model_name,
"input":{
"query": query,
"documents": documents,
},
}
if self.rerank_config:
json_data["parameters"] = self.rerank_config
else:
json_data["parameters"] = {}
json_data["parameters"]["top_n"] = top_k
resp = self._session.post( # type: ignore[assignment]
API_URL,
json=json_data,
).json()
if "output" not in resp:
raise RuntimeError(resp["output"])

results = []
for res in resp["output"]["results"]:
results.append(
RerankResult(
text=res.get("document", {}).get("text", ""),
score=res["relevance_score"],
index=res["index"]
)
)
return results