@@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
5151 client : Any = Field (default = None , exclude = True ) #: :meta private:
5252 async_client : Any = Field (default = None , exclude = True ) #: :meta private:
5353 model : str = "togethercomputer/m2-bert-80M-8k-retrieval"
54- """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
54+ """Embeddings model name to use.
5555 Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
5656 """
5757 dimensions : Optional [int ] = None
@@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
6262 together_api_key : Optional [SecretStr ] = Field (default = None , alias = "api_key" )
6363 """API Key for Solar API."""
6464 together_api_base : str = Field (
65- default = "https://api.together.ai/v1/embeddings " , alias = "base_url"
65+ default = "https://api.together.ai/v1/" , alias = "base_url"
6666 )
6767 """Endpoint URL to use."""
6868 embedding_ctx_length : int = 4096
@@ -166,21 +166,25 @@ def validate_environment(cls, values: Dict) -> Dict:
166166 "default_query" : values ["default_query" ],
167167 }
168168 if not values .get ("client" ):
169- sync_specific = {"http_client" : values ["http_client" ]}
169+ sync_specific = (
170+ {"http_client" : values ["http_client" ]} if values ["http_client" ] else {}
171+ )
170172 values ["client" ] = openai .OpenAI (
171173 ** client_params , ** sync_specific
172174 ).embeddings
173175 if not values .get ("async_client" ):
174- async_specific = {"http_client" : values ["http_async_client" ]}
176+ async_specific = (
177+ {"http_client" : values ["http_async_client" ]}
178+ if values ["http_async_client" ]
179+ else {}
180+ )
175181 values ["async_client" ] = openai .AsyncOpenAI (
176182 ** client_params , ** async_specific
177183 ).embeddings
178184 return values
179185
180186 @property
181187 def _invocation_params (self ) -> Dict [str , Any ]:
182- self .model = self .model .replace ("-query" , "" ).replace ("-passage" , "" )
183-
184188 params : Dict = {"model" : self .model , ** self .model_kwargs }
185189 if self .dimensions is not None :
186190 params ["dimensions" ] = self .dimensions
@@ -197,7 +201,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
197201 """
198202 embeddings = []
199203 params = self ._invocation_params
200- params ["model" ] = params ["model" ] + "-passage"
204+ params ["model" ] = params ["model" ]
201205
202206 for text in texts :
203207 response = self .client .create (input = text , ** params )
@@ -217,7 +221,7 @@ def embed_query(self, text: str) -> List[float]:
217221 Embedding for the text.
218222 """
219223 params = self ._invocation_params
220- params ["model" ] = params ["model" ] + "-query"
224+ params ["model" ] = params ["model" ]
221225
222226 response = self .client .create (input = text , ** params )
223227
@@ -236,7 +240,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
236240 """
237241 embeddings = []
238242 params = self ._invocation_params
239- params ["model" ] = params ["model" ] + "-passage"
243+ params ["model" ] = params ["model" ]
240244
241245 for text in texts :
242246 response = await self .async_client .create (input = text , ** params )
@@ -256,7 +260,7 @@ async def aembed_query(self, text: str) -> List[float]:
256260 Embedding for the text.
257261 """
258262 params = self ._invocation_params
259- params ["model" ] = params ["model" ] + "-query"
263+ params ["model" ] = params ["model" ]
260264
261265 response = await self .async_client .create (input = text , ** params )
262266
0 commit comments