|
2 | 2 | from typing import Any, Dict, Optional |
3 | 3 | from urllib.parse import urljoin |
4 | 4 |
|
5 | | -from fastapi import Request |
6 | 5 | import httpx |
7 | 6 | from openai import AsyncOpenAI |
8 | 7 | import requests |
@@ -63,25 +62,25 @@ def __init__(self, model: str, api_url: str, api_key: str, timeout: int, *args, |
63 | 62 | else: |
64 | 63 | self.vector_size = None |
65 | 64 |
|
66 | | - def _format_request(self, request: Request, json: Optional[dict] = None, files: Optional[dict] = None, data: Optional[dict] = None) -> dict: |
| 65 | + def _format_request(self, json: Optional[dict] = None, files: Optional[dict] = None, data: Optional[dict] = None) -> dict: |
67 | 66 | """ |
68 | 67 | Format a request to a client model. Overridden base class method to support TEI Reranking. |
69 | 68 |
|
70 | 69 | Args: |
71 | | - endpoint(str): The endpoint to forward the request to. |
72 | 70 | json(dict): The JSON body to use for the request. |
73 | 71 | files(dict): The files to use for the request. |
74 | 72 | data(dict): The data to use for the request. |
75 | 73 |
|
76 | 74 | Returns: |
77 | 75 | tuple: The formatted request composed of the url, headers, json, files and data. |
78 | 76 | """ |
79 | | - url = urljoin(base=self.api_url, url=self.ENDPOINT_TABLE[request.url.path.removeprefix("/v1")]) |
| 77 | + # self.endpoint is set by the ModelRouter |
| 78 | + url = urljoin(base=self.api_url, url=self.ENDPOINT_TABLE[self.endpoint]) |
80 | 79 | headers = {"Authorization": f"Bearer {self.api_key}"} |
81 | 80 | if json and "model" in json: |
82 | 81 | json["model"] = self.model |
83 | 82 |
|
84 | | - if request.url.path.endswith(ENDPOINT__RERANK): |
| 83 | + if self.endpoint.endswith(ENDPOINT__RERANK): |
85 | 84 | json = {"query": json["prompt"], "texts": json["input"]} |
86 | 85 |
|
87 | 86 | return url, headers, json, files, data |
|
0 commit comments