Skip to content

Commit bb81ae5

Browse files
author
Erick Friis
authored
together: fix chat model and embedding classes (#21353)
1 parent d6ef5fe commit bb81ae5

3 files changed

Lines changed: 16 additions & 12 deletions

File tree

libs/partners/together/langchain_together/chat_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _llm_type(self) -> str:
5959
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
6060
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
6161
together_api_base: Optional[str] = Field(
62-
default="https://api.together.ai/v1/chat/completions", alias="base_url"
62+
default="https://api.together.ai/v1/", alias="base_url"
6363
)
6464

6565
@root_validator()

libs/partners/together/langchain_together/embeddings.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
5151
client: Any = Field(default=None, exclude=True) #: :meta private:
5252
async_client: Any = Field(default=None, exclude=True) #: :meta private:
5353
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
54-
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
54+
"""Embeddings model name to use.
5555
For example: 'togethercomputer/m2-bert-80M-8k-retrieval'.
5656
"""
5757
dimensions: Optional[int] = None
@@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
6262
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
6363
"""API Key for Together API."""
6464
together_api_base: str = Field(
65-
default="https://api.together.ai/v1/embeddings", alias="base_url"
65+
default="https://api.together.ai/v1/", alias="base_url"
6666
)
6767
"""Endpoint URL to use."""
6868
embedding_ctx_length: int = 4096
@@ -166,21 +166,25 @@ def validate_environment(cls, values: Dict) -> Dict:
166166
"default_query": values["default_query"],
167167
}
168168
if not values.get("client"):
169-
sync_specific = {"http_client": values["http_client"]}
169+
sync_specific = (
170+
{"http_client": values["http_client"]} if values["http_client"] else {}
171+
)
170172
values["client"] = openai.OpenAI(
171173
**client_params, **sync_specific
172174
).embeddings
173175
if not values.get("async_client"):
174-
async_specific = {"http_client": values["http_async_client"]}
176+
async_specific = (
177+
{"http_client": values["http_async_client"]}
178+
if values["http_async_client"]
179+
else {}
180+
)
175181
values["async_client"] = openai.AsyncOpenAI(
176182
**client_params, **async_specific
177183
).embeddings
178184
return values
179185

180186
@property
181187
def _invocation_params(self) -> Dict[str, Any]:
182-
self.model = self.model.replace("-query", "").replace("-passage", "")
183-
184188
params: Dict = {"model": self.model, **self.model_kwargs}
185189
if self.dimensions is not None:
186190
params["dimensions"] = self.dimensions
@@ -197,7 +201,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
197201
"""
198202
embeddings = []
199203
params = self._invocation_params
200-
params["model"] = params["model"] + "-passage"
204+
params["model"] = params["model"]
201205

202206
for text in texts:
203207
response = self.client.create(input=text, **params)
@@ -217,7 +221,7 @@ def embed_query(self, text: str) -> List[float]:
217221
Embedding for the text.
218222
"""
219223
params = self._invocation_params
220-
params["model"] = params["model"] + "-query"
224+
params["model"] = params["model"]
221225

222226
response = self.client.create(input=text, **params)
223227

@@ -236,7 +240,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
236240
"""
237241
embeddings = []
238242
params = self._invocation_params
239-
params["model"] = params["model"] + "-passage"
243+
params["model"] = params["model"]
240244

241245
for text in texts:
242246
response = await self.async_client.create(input=text, **params)
@@ -256,7 +260,7 @@ async def aembed_query(self, text: str) -> List[float]:
256260
Embedding for the text.
257261
"""
258262
params = self._invocation_params
259-
params["model"] = params["model"] + "-query"
263+
params["model"] = params["model"]
260264

261265
response = await self.async_client.create(input=text, **params)
262266

libs/partners/together/tests/integration_tests/test_chat_models_standard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ def chat_model_class(self) -> Type[BaseChatModel]:
1717
@pytest.fixture
1818
def chat_model_params(self) -> dict:
1919
return {
20-
"model": "meta-llama/Llama-3-8b-chat-hf",
20+
"model": "mistralai/Mistral-7B-Instruct-v0.1",
2121
}

0 commit comments

Comments
 (0)