Skip to content

Commit fbf3017

Browse files
Revert mosec embedding microservice to to use synchronous interface. (opea-project#971)
* Revert mosec embedding microservice to to use synchronous interface. Signed-off-by: Yao, Qing <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add dependency. Signed-off-by: Yao, Qing <[email protected]> --------- Signed-off-by: Yao, Qing <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5663e16 commit fbf3017

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

comps/embeddings/mosec/langchain/embedding_mosec.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import List, Optional, Union
88

99
from langchain_community.embeddings import OpenAIEmbeddings
10+
from langchain_community.embeddings.openai import async_embed_with_retry
1011

1112
from comps import (
1213
CustomLogger,
@@ -35,7 +36,7 @@ async def _aget_len_safe_embeddings(
3536
) -> List[List[float]]:
3637
_chunk_size = chunk_size or self.chunk_size
3738
batched_embeddings: List[List[float]] = []
38-
response = self.client.create(input=texts, **self._invocation_params)
39+
response = await async_embed_with_retry(self, input=texts, **self._invocation_params)
3940
if not isinstance(response, dict):
4041
response = response.model_dump()
4142
batched_embeddings.extend(r["embedding"] for r in response["data"])
@@ -45,7 +46,7 @@ async def _aget_len_safe_embeddings(
4546
async def empty_embedding() -> List[float]:
4647
nonlocal _cached_empty_embedding
4748
if _cached_empty_embedding is None:
48-
average_embedded = self.client.create(input="", **self._invocation_params)
49+
average_embedded = await async_embed_with_retry(self, input="", **self._invocation_params)
4950
if not isinstance(average_embedded, dict):
5051
average_embedded = average_embedded.model_dump()
5152
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
@@ -57,6 +58,29 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
5758
embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
5859
return embeddings
5960

61+
def _get_len_safe_embeddings(
62+
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
63+
) -> List[List[float]]:
64+
_chunk_size = chunk_size or self.chunk_size
65+
batched_embeddings: List[List[float]] = []
66+
response = self.client.create(input=texts, **self._invocation_params)
67+
if not isinstance(response, dict):
68+
response = response.model_dump()
69+
batched_embeddings.extend(r["embedding"] for r in response["data"])
70+
71+
_cached_empty_embedding: Optional[List[float]] = None
72+
73+
def empty_embedding() -> List[float]:
74+
nonlocal _cached_empty_embedding
75+
if _cached_empty_embedding is None:
76+
average_embedded = self.client.create(input="", **self._invocation_params)
77+
if not isinstance(average_embedded, dict):
78+
average_embedded = average_embedded.model_dump()
79+
_cached_empty_embedding = average_embedded["data"][0]["embedding"]
80+
return _cached_empty_embedding
81+
82+
return [e if e is not None else empty_embedding() for e in batched_embeddings]
83+
6084

6185
@register_microservice(
6286
name="opea_service@embedding_mosec",
@@ -68,18 +92,18 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
6892
output_datatype=EmbedDoc,
6993
)
7094
@register_statistics(names=["opea_service@embedding_mosec"])
71-
async def embedding(
95+
def embedding(
7296
input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
7397
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
7498
if logflag:
7599
logger.info(input)
76100
start = time.time()
77101
if isinstance(input, TextDoc):
78-
embed_vector = await get_embeddings(input.text)
102+
embed_vector = get_embeddings(input.text)
79103
embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
80104
res = EmbedDoc(text=input.text, embedding=embedding_res)
81105
else:
82-
embed_vector = await get_embeddings(input.input)
106+
embed_vector = get_embeddings(input.input)
83107
if input.dimensions is not None:
84108
embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]
85109

@@ -99,9 +123,9 @@ async def embedding(
99123
return res
100124

101125

102-
async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
126+
def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
103127
texts = [text] if isinstance(text, str) else text
104-
embed_vector = await embeddings.aembed_documents(texts)
128+
embed_vector = embeddings.embed_documents(texts)
105129
return embed_vector
106130

107131

0 commit comments

Comments
 (0)