Skip to content

Commit bf9f169

Browse files
support embedding models in model validation program (#1036)
* support embedding models in model validation program * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert modelcar config file changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove print statement Signed-off-by: Edward Arthur Quarm Jnr <equarmjn@redhat.com> * meaningful logging messages Signed-off-by: Edward Arthur Quarm Jnr <equarmjn@redhat.com> --------- Signed-off-by: Edward Arthur Quarm Jnr <equarmjn@redhat.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b7c9d66 commit bf9f169

3 files changed

Lines changed: 82 additions & 5 deletions

File tree

tests/model_serving/model_runtime/model_validation/constant.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@
7070
],
7171
]
7272

73+
# Representative prompts sent to the /embeddings endpoint during model
# validation; each entry is a single-key {"text": ...} payload.
EMBEDDING_QUERY: list[dict[str, str]] = [
    {"text": "What are the key benefits of renewable energy sources compared to fossil fuels?"},
    {"text": "Translate the following English sentence into Spanish, German, and Mandarin: 'Knowledge is power.'"},
    {"text": "Write a poem about the beauty of the night sky and the mysteries it holds."},
    {"text": "Explain the significance of the Great Wall of China in history and its impact on modern tourism."},
    {"text": "Discuss the ethical implications of using artificial intelligence in healthcare decision-making."},
    {"text": "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history."},
]
85+
7386
PULL_SECRET_ACCESS_TYPE: str = '["Pull"]'
7487
PULL_SECRET_NAME: str = "oci-registry-pull-secret"
7588
SPYRE_INFERENCE_SERVICE_PORT: int = 8000

tests/model_serving/model_runtime/utils.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from tests.model_serving.model_runtime.model_validation.constant import (
99
COMPLETION_QUERY,
10+
EMBEDDING_QUERY,
1011
OPENAI_ENDPOINT_NAME,
1112
AUDIO_FILE_URL,
1213
AUDIO_FILE_LOCAL_PATH,
@@ -35,6 +36,13 @@ def validate_audio_inference_output(model_info: Any, completion_responses: Itera
3536
assert len(completion_responses) > 0, "Completion responses should not be empty"
3637

3738

39+
def validate_embedding_inference_output(model_info: Any, embedding_responses: Iterable[Any]) -> None:
    """Sanity-check the results of an embedding inference run.

    Mirrors ``validate_audio_inference_output``: verifies that model info was
    retrieved and that at least one embedding response was collected.

    Args:
        model_info: Model metadata returned by the models-info endpoint; must
            be a non-None list or tuple. (Annotated ``Any`` but a sized
            list/tuple is required — ``len``/``isinstance`` are used below.)
        embedding_responses: Collected embedding responses; must be a
            non-empty list or tuple despite the ``Iterable`` annotation.

    Raises:
        AssertionError: If any of the structural checks fail (pytest-style
            validation helper; assertion order determines which message fires).
    """
    assert model_info is not None, "Model info should not be None"
    assert isinstance(model_info, (list, tuple)), "Model info should be a list or tuple"
    assert isinstance(embedding_responses, (list, tuple)), "Embedding responses should be a list or tuple"
    assert len(embedding_responses) > 0, "Embedding responses should not be empty"
44+
45+
3846
def fetch_tgis_response( # type: ignore
3947
url: str,
4048
model_name: str,
@@ -87,6 +95,54 @@ def run_raw_inference(
8795
raise NotSupportedError(f"{endpoint} endpoint")
8896

8997

98+
def _collect_embedding_responses(client: OpenAIClient, queries: list[dict[str, str]]) -> list[Any]:
    """Send each query to the /embeddings endpoint and return the responses in order."""
    return [client.request_http(endpoint=OpenAIEnpoints.EMBEDDINGS, query=query) for query in queries]


@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
def run_embedding_inference(
    endpoint: str,
    model_name: str,
    url: Optional[str] = None,
    pod_name: Optional[str] = None,
    isvc: Optional[InferenceService] = None,
    port: Optional[int] = Ports.REST_PORT,
    # Shared module-level default; treated as read-only here.
    embedding_query: list[dict[str, str]] = EMBEDDING_QUERY,
) -> tuple[Any, list[Any]]:
    """Run embedding inference against a model, directly or via port-forward.

    When ``url`` is given, queries are sent straight to that host; otherwise a
    port-forward to ``pod_name`` is established and the "openai" endpoint is
    used over localhost.

    Args:
        endpoint: Inference protocol name; only "openai" is supported on the
            port-forward path (the direct-URL path always uses the OpenAI client).
        model_name: Name of the model to query.
        url: Optional base URL of the inference service.
        pod_name: Pod to port-forward to when ``url`` is not provided.
        isvc: InferenceService owning the pod (supplies the namespace).
        port: Port to forward (both ends).
        embedding_query: Queries to embed; defaults to ``EMBEDDING_QUERY``.

    Returns:
        Tuple of (model info from the models-info endpoint, list of embedding
        responses, one per query).

    Raises:
        ValueError: If ``url`` is absent and any of pod_name/isvc/port is missing.
        NotSupportedError: For a non-"openai" endpoint on the port-forward path.
    """
    LOGGER.info("Running embedding inference for model: %s on endpoint: %s", model_name, endpoint)

    if url is not None:
        LOGGER.info("Using provided URL for inference: %s", url)
        # NOTE(review): streaming=True mirrors the other inference helpers, but
        # embeddings responses are not streamed — confirm the flag is benign here.
        inference_client = OpenAIClient(host=url, model_name=model_name, streaming=True)
        embedding_responses = _collect_embedding_responses(inference_client, embedding_query)
        model_info = OpenAIClient.get_request_http(host=url, endpoint=OpenAIEnpoints.MODELS_INFO)
        return model_info, embedding_responses

    # Validate before logging so a misleading "using port forwarding" message
    # is not emitted on the error path.
    if pod_name is None or isvc is None or port is None:
        raise ValueError("pod_name, isvc, and port are required when url is not provided")
    LOGGER.info("Using port forwarding for inference on pod: %s", pod_name)

    with portforward.forward(
        pod_or_service=pod_name,
        namespace=isvc.namespace,
        from_port=port,
        to_port=port,
    ):
        if endpoint != "openai":
            raise NotSupportedError(f"{endpoint} endpoint for embedding inference")
        host = f"http://localhost:{port}"
        inference_client = OpenAIClient(host=host, model_name=model_name, streaming=True)
        embedding_responses = _collect_embedding_responses(inference_client, embedding_query)
        model_info = OpenAIClient.get_request_http(host=host, endpoint=OpenAIEnpoints.MODELS_INFO)
        return model_info, embedding_responses
144+
145+
90146
@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
91147
def run_audio_inference(
92148
endpoint: str,
@@ -98,7 +154,7 @@ def run_audio_inference(
98154
isvc: Optional[InferenceService] = None,
99155
port: Optional[int] = Ports.REST_PORT,
100156
) -> tuple[Any, list[Any]]:
101-
LOGGER.info(pod_name)
157+
LOGGER.info("Running audio inference for model: %s on endpoint: %s", model_name, endpoint)
102158
download_audio_file(audio_file_url=audio_file_url, destination_path=audio_file_path)
103159

104160
if url is not None:
@@ -177,6 +233,16 @@ def validate_raw_openai_inference_request(
177233
completion_responses,
178234
response_snapshot=response_snapshot,
179235
)
236+
elif model_output_type == "embedding":
237+
model_info, embedding_responses = run_embedding_inference(
238+
pod_name=pod_name,
239+
isvc=isvc,
240+
port=port,
241+
endpoint=OPENAI_ENDPOINT_NAME,
242+
embedding_query=EMBEDDING_QUERY,
243+
model_name=model_name,
244+
)
245+
validate_embedding_inference_output(model_info=model_info, embedding_responses=embedding_responses)
180246

181247
else:
182248
raise NotSupportedError(f"Model output type {model_output_type} is not supported for raw inference request.")

utilities/plugins/openai_plugin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ def _construct_request_data(
201201
elif OpenAIEnpoints.EMBEDDINGS in endpoint:
202202
data = {
203203
"input": query["text"],
204-
"encoding_format": 0.1,
205-
"temperature": 0,
206204
}
207205
else:
208206
data = {"prompt": query["text"], "temperature": 0, "top_p": 0.9, "seed": 1037, "stream": streaming}
@@ -230,8 +228,8 @@ def _parse_response(self, endpoint: str, message: dict[str, Any]) -> Any:
230228
LOGGER.info(message["choices"][0])
231229
return message["choices"][0]
232230
elif OpenAIEnpoints.EMBEDDINGS in endpoint:
233-
LOGGER.info(message["choices"][0])
234-
return message["choices"][0]
231+
LOGGER.info(message["data"][0])
232+
return message["data"][0]
235233
elif OpenAIEnpoints.AUDIO_TRANSCRIPTION in endpoint:
236234
LOGGER.info(message["text"])
237235
return message["text"]

0 commit comments

Comments
 (0)