 
 from tests.model_serving.model_runtime.model_validation.constant import (
     COMPLETION_QUERY,
+    EMBEDDING_QUERY,
     OPENAI_ENDPOINT_NAME,
     AUDIO_FILE_URL,
     AUDIO_FILE_LOCAL_PATH,
@@ -35,6 +36,13 @@ def validate_audio_inference_output(model_info: Any, completion_responses: Itera
     assert len(completion_responses) > 0, "Completion responses should not be empty"
 
 
+def validate_embedding_inference_output(model_info: Any, embedding_responses: Iterable[Any]) -> None:
+    assert model_info is not None, "Model info should not be None"
+    assert isinstance(model_info, (list, tuple)), "Model info should be a list or tuple"
+    assert isinstance(embedding_responses, (list, tuple)), "Embedding responses should be a list or tuple"
+    assert len(embedding_responses) > 0, "Embedding responses should not be empty"
+
+
 def fetch_tgis_response(  # type: ignore
     url: str,
     model_name: str,
@@ -87,6 +95,56 @@ def run_raw_inference(
         raise NotSupportedError(f"{endpoint} endpoint")
 
 
+@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
+def run_embedding_inference(
+    endpoint: str,
+    model_name: str,
+    url: Optional[str] = None,
+    pod_name: Optional[str] = None,
+    isvc: Optional[InferenceService] = None,
+    port: Optional[int] = Ports.REST_PORT,
+    embedding_query: list[dict[str, str]] = EMBEDDING_QUERY,
+) -> tuple[Any, list[Any]]:
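+    # Two request paths: a caller-supplied URL is used directly; otherwise the
+    # requests are tunnelled through a port-forward to the serving pod
+    # (pod_name, isvc, and port are required in that case).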
+    LOGGER.info("Running embedding inference for model: %s on endpoint: %s", model_name, endpoint)
+    if url is not None:
+        LOGGER.info("Using provided URL for inference: %s", url)
+        inference_client = OpenAIClient(host=url, model_name=model_name, streaming=True)
+        embedding_responses = []
+        for query in embedding_query:
+            embedding_response = inference_client.request_http(
+                endpoint=OpenAIEnpoints.EMBEDDINGS,
+                query=query,
+            )
+            embedding_responses.append(embedding_response)
+        model_info = OpenAIClient.get_request_http(host=url, endpoint=OpenAIEnpoints.MODELS_INFO)
+        return model_info, embedding_responses
+    else:
+        LOGGER.info("Using port forwarding for inference on pod: %s", pod_name)
+        if pod_name is None or isvc is None or port is None:
+            raise ValueError("pod_name, isvc, and port are required when url is not provided")
+
+        with portforward.forward(
+            pod_or_service=pod_name,
+            namespace=isvc.namespace,
+            from_port=port,
+            to_port=port,
+        ):
+            if endpoint == "openai":
+                embedding_responses = []
+                inference_client = OpenAIClient(host=f"http://localhost:{port}", model_name=model_name, streaming=True)
+                for query in embedding_query:
+                    embedding_response = inference_client.request_http(endpoint=OpenAIEnpoints.EMBEDDINGS, query=query)
+                    embedding_responses.append(embedding_response)
+                model_info = OpenAIClient.get_request_http(
+                    host=f"http://localhost:{port}", endpoint=OpenAIEnpoints.MODELS_INFO
+                )
+                return model_info, embedding_responses
+            else:
+                raise NotSupportedError(f"{endpoint} endpoint for embedding inference")
+
+
 @retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
 def run_audio_inference(
     endpoint: str,
@@ -98,7 +154,7 @@ def run_audio_inference(
     isvc: Optional[InferenceService] = None,
     port: Optional[int] = Ports.REST_PORT,
 ) -> tuple[Any, list[Any]]:
-    LOGGER.info(pod_name)
+    LOGGER.info("Running audio inference for model: %s on endpoint: %s", model_name, endpoint)
     download_audio_file(audio_file_url=audio_file_url, destination_path=audio_file_path)
 
     if url is not None:
@@ -177,6 +233,16 @@ def validate_raw_openai_inference_request(
             completion_responses,
             response_snapshot=response_snapshot,
         )
+    elif model_output_type == "embedding":
+        model_info, embedding_responses = run_embedding_inference(
+            pod_name=pod_name,
+            isvc=isvc,
+            port=port,
+            endpoint=OPENAI_ENDPOINT_NAME,
+            embedding_query=EMBEDDING_QUERY,
+            model_name=model_name,
+        )
+        validate_embedding_inference_output(model_info=model_info, embedding_responses=embedding_responses)
 
     else:
         raise NotSupportedError(f"Model output type {model_output_type} is not supported for raw inference request.")
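
A minimal usage sketch of the new helper outside the pytest dispatch, assuming a directly reachable OpenAI-compatible route. The real EMBEDDING_QUERY payloads live in model_validation.constant and are not shown in this diff, so the query shape and model name below are assumptions for illustration only:

    # Hypothetical example: the query dict shape and "bge-large-en" model name
    # are assumed; the suite's real payloads come from
    # model_validation.constant.EMBEDDING_QUERY.
    example_queries = [{"input": "kserve embedding smoke test", "model": "bge-large-en"}]

    # Direct-URL path: no port-forward is opened, pod_name/isvc/port stay unset.
    model_info, responses = run_embedding_inference(
        endpoint=OPENAI_ENDPOINT_NAME,
        model_name="bge-large-en",
        url="https://my-embedding-route.example.com",
        embedding_query=example_queries,
    )
    validate_embedding_inference_output(model_info=model_info, embedding_responses=responses)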