from typing import List, Optional, Union

from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.embeddings.openai import async_embed_with_retry

from comps import (
    CustomLogger,
@@ -35,7 +36,7 @@ async def _aget_len_safe_embeddings(
    ) -> List[List[float]]:
        _chunk_size = chunk_size or self.chunk_size
        batched_embeddings: List[List[float]] = []
-        response = self.client.create(input=texts, **self._invocation_params)
+        response = await async_embed_with_retry(self, input=texts, **self._invocation_params)
        if not isinstance(response, dict):
            response = response.model_dump()
        batched_embeddings.extend(r["embedding"] for r in response["data"])
@@ -45,7 +46,7 @@ async def _aget_len_safe_embeddings(
        async def empty_embedding() -> List[float]:
            nonlocal _cached_empty_embedding
            if _cached_empty_embedding is None:
-                average_embedded = self.client.create(input="", **self._invocation_params)
+                average_embedded = await async_embed_with_retry(self, input="", **self._invocation_params)
                if not isinstance(average_embedded, dict):
                    average_embedded = average_embedded.model_dump()
                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
@@ -57,6 +58,29 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
        embeddings = await asyncio.gather(*[get_embedding(e) for e in batched_embeddings])
        return embeddings

+    def _get_len_safe_embeddings(
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+    ) -> List[List[float]]:
+        _chunk_size = chunk_size or self.chunk_size
+        batched_embeddings: List[List[float]] = []
+        response = self.client.create(input=texts, **self._invocation_params)
+        if not isinstance(response, dict):
+            response = response.model_dump()
+        batched_embeddings.extend(r["embedding"] for r in response["data"])
+
+        _cached_empty_embedding: Optional[List[float]] = None
+
+        def empty_embedding() -> List[float]:
+            nonlocal _cached_empty_embedding
+            if _cached_empty_embedding is None:
+                average_embedded = self.client.create(input="", **self._invocation_params)
+                if not isinstance(average_embedded, dict):
+                    average_embedded = average_embedded.model_dump()
+                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
+            return _cached_empty_embedding
+
+        return [e if e is not None else empty_embedding() for e in batched_embeddings]
+

@register_microservice(
    name="opea_service@embedding_mosec",
@@ -68,18 +92,18 @@ async def get_embedding(e: Optional[List[float]]) -> List[float]:
    output_datatype=EmbedDoc,
)
@register_statistics(names=["opea_service@embedding_mosec"])
-async def embedding(
+def embedding(
    input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
    if logflag:
        logger.info(input)
    start = time.time()
    if isinstance(input, TextDoc):
-        embed_vector = await get_embeddings(input.text)
+        embed_vector = get_embeddings(input.text)
        embedding_res = embed_vector[0] if isinstance(input.text, str) else embed_vector
        res = EmbedDoc(text=input.text, embedding=embedding_res)
    else:
-        embed_vector = await get_embeddings(input.input)
+        embed_vector = get_embeddings(input.input)
        if input.dimensions is not None:
            embed_vector = [embed_vector[i][: input.dimensions] for i in range(len(embed_vector))]

@@ -99,9 +123,9 @@ async def embedding(
    return res


-async def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
+def get_embeddings(text: Union[str, List[str]]) -> List[List[float]]:
    texts = [text] if isinstance(text, str) else text
-    embed_vector = await embeddings.aembed_documents(texts)
+    embed_vector = embeddings.embed_documents(texts)
    return embed_vector

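A short usage sketch (not part of the commit) of how the now-synchronous handler could be exercised. It assumes `TextDoc`/`EmbedDoc` from `comps` as imported at the top of the file, that the decorated `embedding()` function remains directly callable, and an illustrative input string:

from comps import TextDoc

# Hypothetical call, for illustration only: after this change the handler is a
# plain function, so no `await` (and no running event loop) is needed.
doc = TextDoc(text="What is deep learning?")
res = embedding(doc)
print(len(res.embedding))  # dimensionality of the single returned vector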