Skip to content
13 changes: 13 additions & 0 deletions tests/model_serving/model_runtime/model_validation/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@
],
]

# Prompts sent to the /embeddings endpoint during model-validation tests.
# Each entry is a single-key {"text": ...} payload, as expected by
# OpenAIClient._construct_request_data for the embeddings endpoint.
EMBEDDING_QUERY: list[dict[str, str]] = [
    {"text": prompt}
    for prompt in (
        "What are the key benefits of renewable energy sources compared to fossil fuels?",
        "Translate the following English sentence into Spanish, German, and Mandarin: 'Knowledge is power.'",
        "Write a poem about the beauty of the night sky and the mysteries it holds.",
        "Explain the significance of the Great Wall of China in history and its impact on modern tourism.",
        "Discuss the ethical implications of using artificial intelligence in healthcare decision-making.",
        "Summarize the main events of the Apollo 11 moon landing and its importance in space exploration history.",
    )
]

# JSON-encoded list of access modes for the registry pull secret
# (presumably consumed as an annotation/field value — confirm against usage).
PULL_SECRET_ACCESS_TYPE: str = '["Pull"]'
# Name of the Kubernetes secret used to pull model images from the OCI registry.
PULL_SECRET_NAME: str = "oci-registry-pull-secret"
# REST port the Spyre inference service listens on.
SPYRE_INFERENCE_SERVICE_PORT: int = 8000
Expand Down
67 changes: 67 additions & 0 deletions tests/model_serving/model_runtime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tests.model_serving.model_runtime.model_validation.constant import (
COMPLETION_QUERY,
EMBEDDING_QUERY,
OPENAI_ENDPOINT_NAME,
AUDIO_FILE_URL,
AUDIO_FILE_LOCAL_PATH,
Expand Down Expand Up @@ -35,6 +36,13 @@ def validate_audio_inference_output(model_info: Any, completion_responses: Itera
assert len(completion_responses) > 0, "Completion responses should not be empty"


def validate_embedding_inference_output(model_info: Any, embedding_responses: Iterable[Any]) -> None:
    """Assert that an embedding inference run produced well-formed outputs.

    Checks that model info is a non-None list/tuple and that the collected
    embedding responses form a non-empty list/tuple.
    """
    checks = (
        (model_info is not None, "Model info should not be None"),
        (isinstance(model_info, (list, tuple)), "Model info should be a list or tuple"),
        (isinstance(embedding_responses, (list, tuple)), "Embedding responses should be a list or tuple"),
    )
    for passed, message in checks:
        assert passed, message
    assert len(embedding_responses) > 0, "Embedding responses should not be empty"


def fetch_tgis_response( # type: ignore
url: str,
model_name: str,
Expand Down Expand Up @@ -87,6 +95,55 @@ def run_raw_inference(
raise NotSupportedError(f"{endpoint} endpoint")


@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
def run_embedding_inference(
    endpoint: str,
    model_name: str,
    url: Optional[str] = None,
    pod_name: Optional[str] = None,
    isvc: Optional[InferenceService] = None,
    port: Optional[int] = Ports.REST_PORT,
    embedding_query: list[dict[str, str]] = EMBEDDING_QUERY,
) -> tuple[Any, list[Any]]:
    """
    Run embedding inference, either against a direct URL or via pod port-forwarding.

    Retried up to 5 times with exponential backoff (tenacity) on any exception.

    Args:
        endpoint: Inference protocol; only "openai" is supported on the port-forward path.
        model_name: Name of the model to query.
        url: Direct inference URL. When provided, the port-forward path is skipped.
        pod_name: Pod to port-forward to; required when url is not provided.
        isvc: InferenceService owning the pod (its namespace is used); required when url is not provided.
        port: Local and remote port for port-forwarding; required when url is not provided.
        embedding_query: Payloads to embed; each dict carries a "text" key.

    Returns:
        Tuple of (model info from the models endpoint, list of embedding responses).

    Raises:
        ValueError: If url is not provided and any of pod_name/isvc/port is missing.
        NotSupportedError: If the endpoint is not supported for embedding inference.
    """
    if url is not None:
        LOGGER.info("Using provided URL for inference: %s", url)
        inference_client = OpenAIClient(host=url, model_name=model_name, streaming=True)
        embedding_responses = []
        for query in embedding_query:
            # Use the module logger (lazy %-args) instead of print so output is
            # captured by the test log handlers, matching the rest of this module.
            LOGGER.info("Sending embedding request for query: %s", query)
            embedding_response = inference_client.request_http(
                endpoint=OpenAIEnpoints.EMBEDDINGS,
                query=query,
            )
            embedding_responses.append(embedding_response)
        model_info = OpenAIClient.get_request_http(host=url, endpoint=OpenAIEnpoints.MODELS_INFO)
        return model_info, embedding_responses

    LOGGER.info("Using port forwarding for inference on pod: %s", pod_name)
    if pod_name is None or isvc is None or port is None:
        raise ValueError("pod_name, isvc, and port are required when url is not provided")

    with portforward.forward(
        pod_or_service=pod_name,
        namespace=isvc.namespace,
        from_port=port,
        to_port=port,
    ):
        if endpoint != "openai":
            raise NotSupportedError(f"{endpoint} endpoint for embedding inference")

        host = f"http://localhost:{port}"
        inference_client = OpenAIClient(host=host, model_name=model_name, streaming=True)
        embedding_responses = []
        for query in embedding_query:
            LOGGER.info("Sending embedding request for query: %s", query)
            embedding_response = inference_client.request_http(endpoint=OpenAIEnpoints.EMBEDDINGS, query=query)
            embedding_responses.append(embedding_response)
        model_info = OpenAIClient.get_request_http(host=host, endpoint=OpenAIEnpoints.MODELS_INFO)
        return model_info, embedding_responses


@retry(stop=stop_after_attempt(5), wait=wait_exponential(min=1, max=6))
def run_audio_inference(
endpoint: str,
Expand Down Expand Up @@ -177,6 +234,16 @@ def validate_raw_openai_inference_request(
completion_responses,
response_snapshot=response_snapshot,
)
elif model_output_type == "embedding":
model_info, embedding_responses = run_embedding_inference(
pod_name=pod_name,
isvc=isvc,
port=port,
endpoint=OPENAI_ENDPOINT_NAME,
embedding_query=EMBEDDING_QUERY,
model_name=model_name,
)
validate_embedding_inference_output(model_info=model_info, embedding_responses=embedding_responses)

else:
raise NotSupportedError(f"Model output type {model_output_type} is not supported for raw inference request.")
Expand Down
6 changes: 2 additions & 4 deletions utilities/plugins/openai_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,6 @@ def _construct_request_data(
elif OpenAIEnpoints.EMBEDDINGS in endpoint:
data = {
"input": query["text"],
"encoding_format": 0.1,
"temperature": 0,
}
else:
data = {"prompt": query["text"], "temperature": 0, "top_p": 0.9, "seed": 1037, "stream": streaming}
Expand Down Expand Up @@ -230,8 +228,8 @@ def _parse_response(self, endpoint: str, message: dict[str, Any]) -> Any:
LOGGER.info(message["choices"][0])
return message["choices"][0]
elif OpenAIEnpoints.EMBEDDINGS in endpoint:
LOGGER.info(message["choices"][0])
return message["choices"][0]
LOGGER.info(message["data"][0])
return message["data"][0]
elif OpenAIEnpoints.AUDIO_TRANSCRIPTION in endpoint:
LOGGER.info(message["text"])
return message["text"]
Expand Down