1111 "cohere": "https://api.cohere.ai/v1/rerank",
1212 "jina": "https://api.jina.ai/v1/rerank",
1313 "voyage": "https://api.voyageai.com/v1/rerank",
14- "mixedbread": NotImplemented ,
14+ "mixedbread.ai ": "https://api.mixedbread.ai/v1/reranking" ,
1515}
1616
1717
@@ -29,24 +29,36 @@ def __init__(self, model: str, api_key: str, api_provider: str, verbose: int = 1
2929 }
3030 self.url = URLS[self.api_provider]
3131
32+ def _get_document_text(self, r: dict) -> str:
33+ if self.api_provider == "voyage":
34+ return r["document"]
35+ elif self.api_provider == "mixedbread.ai":
36+ return r["input"]
37+ else:
38+ return r["document"]["text"]
39+
40+ def _get_score(self, r: dict) -> float:
41+ if self.api_provider == "mixedbread.ai":
42+ return r["score"]
43+ return r["relevance_score"]
44+
3245 def _parse_response(
3346 self, response: dict, doc_ids: Union[List[str], List[int]]
3447 ) -> RankedResults:
3548 ranked_docs = []
36- results_key = "results" if self.api_provider != "voyage" else "data"
49+ results_key = (
50+ "results"
51+ if self.api_provider not in ["voyage", "mixedbread.ai"]
52+ else "data"
53+ )
3754 print(response)
3855
3956 for i, r in enumerate(response[results_key]):
40- document_text = (
41- r["document"]
42- if self.api_provider == "voyage"
43- else r["document"]["text"]
44- )
4557 ranked_docs.append(
4658 Result(
4759 doc_id=doc_ids[r["index"]],
48- text=document_text ,
49- score=r["relevance_score"] ,
60+ text=self._get_document_text(r) ,
61+ score=self._get_score(r) ,
5062 rank=i + 1,
5163 )
5264 )
@@ -67,13 +79,22 @@ def rank(
6779 return RankedResults(results=results, query=query, has_scores=True)
6880
6981 def _format_payload(self, query: str, docs: List[str]) -> str:
70- top_key = "top_n" if self.api_provider != "voyage" else "top_k"
82+ top_key = (
83+ "top_n" if self.api_provider not in ["voyage", "mixedbread.ai"] else "top_k"
84+ )
85+ documents_key = "documents" if self.api_provider != "mixedbread.ai" else "input"
86+ return_documents_key = (
87+ "return_documents"
88+ if self.api_provider != "mixedbread.ai"
89+ else "return_input"
90+ )
91+
7192 payload = {
7293 "model": self.model,
7394 "query": query,
74- "documents" : docs,
95+ documents_key : docs,
7596 top_key: len(docs),
76- "return_documents" : True,
97+ return_documents_key : True,
7798 }
7899 return json.dumps(payload)
79100
0 commit comments