Skip to content

Commit 3cf0ffd

Browse files
maryamtahhan and Claude committed
Add comprehensive embeddings benchmark support
Implements full embeddings benchmarking capability including schemas, quality validation (cosine similarity, MTEB), output formatters (CSV, HTML, JSON, console), mock server handler, CLI integration, and comprehensive test suite. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: Maryam Tahhan <mtahhan@redhat.com>
1 parent 5cad9dc commit 3cf0ffd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+8455
-16
lines changed

pyproject.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ dependencies = [
7474

7575
[project.optional-dependencies]
7676
# Meta Extras
77-
all = ["guidellm[perf,tokenizers,audio,vision]"]
77+
all = ["guidellm[perf,tokenizers,audio,vision,embeddings]"]
7878
recommended = ["guidellm[perf,tokenizers]"]
7979
# Feature Extras
8080
perf = ["orjson", "msgpack", "msgspec", "uvloop"]
@@ -90,6 +90,12 @@ vision = [
9090
"datasets[vision]",
9191
"pillow",
9292
]
93+
embeddings = [
94+
# Quality validation with baseline models
95+
"sentence-transformers>=2.2.0",
96+
# MTEB benchmark integration
97+
"mteb>=1.0.0",
98+
]
9399
# Dev Tooling
94100
dev = [
95101
# Install all optional dependencies
@@ -179,7 +185,9 @@ module = [
179185
"transformers.*",
180186
"setuptools.*",
181187
"setuptools_git_versioning.*",
182-
"torchcodec.*"
188+
"torchcodec.*",
189+
"sentence_transformers.*",
190+
"mteb.*"
183191
]
184192
ignore_missing_imports = true
185193

src/guidellm/__main__.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,3 +794,226 @@ def mock_server(
794794

795795
@benchmark.command(
    "embeddings",
    help=(
        "Run embeddings benchmark with optional quality validation. "
        "Supports cosine similarity validation and MTEB benchmark evaluation."
    ),
    context_settings={"auto_envvar_prefix": "GUIDELLM"},
)
@click.option(
    "--target",
    type=str,
    required=True,
    help="Target backend URL (e.g., http://localhost:8000).",
)
@click.option(
    "--data",
    type=str,
    multiple=True,
    required=True,
    help=(
        "HuggingFace dataset ID, path to dataset, path to data file "
        "(csv/json/jsonl/txt), or synthetic data config."
    ),
)
@click.option(
    "--profile",
    default="sweep",
    type=click.Choice(STRATEGY_PROFILE_CHOICES),
    help=f"Benchmark profile type. Options: {', '.join(STRATEGY_PROFILE_CHOICES)}.",
)
@click.option(
    "--rate",
    callback=cli_tools.parse_list_floats,
    multiple=True,
    default=None,
    help="Benchmark rate(s) to test. Meaning depends on profile.",
)
@click.option(
    "--backend",
    type=click.Choice(list(get_literal_vals(BackendType))),
    default="openai_http",
    help=f"Backend type. Options: {', '.join(get_literal_vals(BackendType))}.",
)
@click.option(
    "--backend-kwargs",
    callback=cli_tools.parse_json,
    default=None,
    help='JSON string of backend arguments. E.g., \'{"api_key": "key"}\'',
)
@click.option(
    "--model",
    default=None,
    type=str,
    help="Model ID to benchmark. If not provided, uses first available model.",
)
@click.option(
    "--request-format",
    default="embeddings",
    help="Format to use for requests (default: embeddings).",
)
@click.option(
    "--processor",
    default=None,
    type=str,
    help="Processor or tokenizer for token counts. If not provided, loads from model.",
)
@click.option(
    "--data-samples",
    default=-1,
    type=int,
    help="Number of samples from dataset. -1 (default) uses all samples.",
)
@click.option(
    "--outputs",
    default=["json", "csv", "html"],
    callback=cli_tools.parse_csv,
    help=(
        "Comma-separated list of output formats: json,csv,html,console. "
        "Default: json,csv,html"
    ),
)
@click.option(
    "--output-dir",
    type=click.Path(file_okay=False, dir_okay=True, path_type=Path),
    default=Path.cwd(),
    help="Directory to save output files. Default: current directory.",
)
@click.option(
    "--max-requests",
    default=None,
    type=int,
    help="Maximum number of requests to execute.",
)
@click.option(
    "--max-errors",
    default=None,
    type=int,
    help="Maximum number of errors before stopping benchmark.",
)
@click.option(
    "--max-duration",
    default=None,
    type=float,
    help="Maximum duration in seconds for benchmark execution.",
)
# Embeddings-specific quality validation options
@click.option(
    "--enable-quality-validation",
    is_flag=True,
    default=False,
    help="Enable quality validation using cosine similarity against baseline model.",
)
@click.option(
    "--baseline-model",
    default=None,
    type=str,
    help=(
        "HuggingFace model for baseline comparison. "
        "E.g., 'sentence-transformers/all-MiniLM-L6-v2'. "
        "Defaults to target model if not specified."
    ),
)
@click.option(
    "--quality-tolerance",
    default=1e-2,
    type=float,
    help=(
        "Cosine similarity tolerance threshold. "
        "Default: 1e-2 (standard), use 5e-4 for MTEB-level validation."
    ),
)
@click.option(
    "--enable-mteb",
    is_flag=True,
    default=False,
    help="Enable MTEB benchmark evaluation for standardized quality scoring.",
)
@click.option(
    "--mteb-tasks",
    callback=cli_tools.parse_csv,
    default=None,
    help=(
        "Comma-separated list of MTEB tasks. "
        "Default: STS12,STS13,STSBenchmark. E.g., 'STS12,STS13,STS14'"
    ),
)
@click.option(
    "--encoding-format",
    type=click.Choice(["float", "base64"]),
    default="float",
    help="Embedding encoding format. Options: float, base64. Default: float.",
)
@click.option(
    "--disable-console",
    is_flag=True,
    default=False,
    help="Disable all console output (including progress display).",
)
@click.option(
    "--disable-console-interactive",
    is_flag=True,
    default=False,
    help="Disable interactive console elements (progress bar, tables).",
)
@click.option(
    "--random-seed",
    default=42,
    type=int,
    help="Random seed for reproducibility. Default: 42.",
)
def embeddings(**kwargs):
    """
    Run an embeddings benchmark with optional quality validation.

    Collects all CLI options, validates them into a
    :class:`BenchmarkEmbeddingsArgs` instance, and runs the async
    ``benchmark_embeddings`` entrypoint. Validation failures are surfaced
    as ``click.BadParameter`` errors pointing at the offending option.
    """
    # Imported lazily so the optional embeddings extras are only required
    # when this subcommand is actually invoked.
    from guidellm.benchmark.embeddings_entrypoints import benchmark_embeddings
    from guidellm.benchmark.schemas.embeddings import BenchmarkEmbeddingsArgs

    # Only keep CLI args that differ from click defaults so environment
    # variables / scenario defaults are not clobbered by click's defaults.
    kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)

    # Console options: disabling the console entirely implies disabling
    # the interactive progress display as well.
    disable_console = kwargs.pop("disable_console", False)
    disable_console_interactive = (
        kwargs.pop("disable_console_interactive", False) or disable_console
    )
    console = Console() if not disable_console else None

    # Warn about environment variables that may override configuration.
    envs = cli_tools.list_set_env()
    if console and envs:
        console.print_update(
            title=(
                "Note: the following environment variables "
                "are set and **may** affect configuration"
            ),
            details=", ".join(envs),
            status="warning",
        )

    try:
        args = BenchmarkEmbeddingsArgs.create(scenario=None, **kwargs)
    except ValidationError as err:
        # Report the first validation error against its CLI option name.
        errs = err.errors(include_url=False, include_context=True, include_input=True)
        param_name = "--" + str(errs[0]["loc"][0]).replace("_", "-")
        raise click.BadParameter(
            errs[0]["msg"], ctx=click.get_current_context(), param_hint=param_name
        ) from err

    # Use uvloop's faster event loop when the optional dependency is present.
    if uvloop is not None:
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

    asyncio.run(
        benchmark_embeddings(
            args=args,
            progress=(
                GenerativeConsoleBenchmarkerProgress()
                if not disable_console_interactive
                else None
            ),
            console=console,
        )
    )


# NOTE: the entry-point guard must come AFTER all command definitions.
# In the original diff it preceded the `embeddings` command, so running this
# module directly invoked `cli()` before the subcommand was registered.
if __name__ == "__main__":
    cli()
1018+
1019+

src/guidellm/backends/openai/request_handlers.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
__all__ = [
2121
"AudioRequestHandler",
2222
"ChatCompletionsRequestHandler",
23+
"EmbeddingsRequestHandler",
2324
"OpenAIRequestHandler",
2425
"OpenAIRequestHandlerFactory",
2526
"TextCompletionsRequestHandler",
@@ -667,3 +668,113 @@ def extract_metrics(
667668
text_words=len(text.split()) if text else 0,
668669
text_characters=len(text) if text else 0,
669670
)
671+
672+
673+
@OpenAIRequestHandlerFactory.register("embeddings")
class EmbeddingsRequestHandler(OpenAIRequestHandler):
    """
    Request handler for OpenAI-style embeddings endpoints.

    Handles embeddings requests which do not support streaming and return
    embedding vectors instead of generated text. Processes input text into
    embeddings with optional quality validation support.
    """

    def format(
        self,
        data: GenerationRequest,
        **kwargs,
    ) -> GenerationRequestArguments:
        """
        Format the embeddings generation request.

        :param data: The generation request to format
        :param kwargs: Additional keyword arguments (model, encoding_format,
            dimensions, truncate_prompt_tokens, extras)
        :return: The formatted request arguments
        """
        arguments = GenerationRequestArguments()
        arguments.body = {}
        arguments.stream = False  # Embeddings never stream

        # Add model
        if kwargs.get("model") is not None:
            arguments.body["model"] = kwargs["model"]

        # Build input from text columns, skipping empty/None entries
        input_texts = [
            text for text in data.columns.get("text_column", []) if text
        ]

        # OpenAI's API accepts either a single string or a list; use a
        # single string when only one text is present.
        if len(input_texts) == 1:
            arguments.body["input"] = input_texts[0]
        else:
            arguments.body["input"] = input_texts

        # Add optional parameters (truthiness checks: falsy values such as
        # "", 0, or None are deliberately omitted from the request body)
        if kwargs.get("encoding_format"):
            arguments.body["encoding_format"] = kwargs["encoding_format"]
        if kwargs.get("dimensions"):
            arguments.body["dimensions"] = kwargs["dimensions"]
        if kwargs.get("truncate_prompt_tokens"):
            arguments.body["truncate_prompt_tokens"] = kwargs["truncate_prompt_tokens"]

        # Apply extra arguments last so callers can override anything above
        if kwargs.get("extras"):
            arguments.body.update(kwargs["extras"])

        return arguments

    def compile_non_streaming(
        self,
        request: GenerationRequest,
        arguments: GenerationRequestArguments,
        response: Any,
    ) -> GenerationResponse:
        """
        Process a complete non-streaming embeddings API response.

        :param request: Original generation request
        :param arguments: Request arguments used
        :param response: Raw API response data
        :return: GenerationResponse with embeddings usage metrics
        """
        # Only token usage is extracted here; the embedding vectors in
        # response["data"] are not carried on the GenerationResponse.
        usage = response.get("usage", {})

        # Build response (no text output for embeddings)
        return GenerationResponse(
            request_id=request.request_id,
            text="",  # Embeddings don't generate text
            input_metrics=UsageMetrics(
                text_tokens=usage.get("prompt_tokens", 0),
            ),
            output_metrics=UsageMetrics(
                text_tokens=0,  # No output tokens for embeddings
            ),
        )

    def add_streaming_line(self, line: str) -> int | None:
        """
        Embeddings do not support streaming.

        :param line: Streaming line (unused)
        :raises NotImplementedError: Always; embeddings never stream
        """
        raise NotImplementedError("Embeddings do not support streaming")

    def compile_streaming(
        self, request: GenerationRequest, arguments: GenerationRequestArguments
    ) -> GenerationResponse:
        """
        Embeddings do not support streaming.

        :param request: Generation request (unused)
        :param arguments: Request arguments (unused)
        :raises NotImplementedError: Always; embeddings never stream
        """
        raise NotImplementedError("Embeddings do not support streaming")

0 commit comments

Comments
 (0)