Skip to content

Commit 2a3ec4c

Browse files
committed
feat: Add image_retrieval endpoint and file loading images/video
1 parent e5194fa commit 2a3ec4c

26 files changed

+1317
-83
lines changed

src/aiperf/common/enums/metric_enums.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ class GenericMetricUnit(BaseMetricUnit):
188188
RATIO = _unit("ratio")
189189
USER = _unit("user")
190190
PERCENT = _unit("%")
191+
IMAGE = _unit("image")
192+
IMAGES = _unit("images")
193+
VIDEO = _unit("video")
194+
VIDEOS = _unit("videos")
191195

192196

193197
class PowerMetricUnitInfo(BaseMetricUnitInfo):
@@ -289,7 +293,11 @@ class MetricOverTimeUnitInfo(BaseMetricUnitInfo):
289293
@model_validator(mode="after")
290294
def _set_tag(self: Self) -> Self:
291295
"""Set the tag based on the existing units. ie. requests/sec, tokens/sec, etc."""
292-
self.tag = f"{self.primary_unit}/{self.time_unit}"
296+
self.tag = (
297+
f"{self.primary_unit}/{self.time_unit}"
298+
if not self.inverted
299+
else f"{self.time_unit}/{self.primary_unit}"
300+
)
293301
if self.third_unit:
294302
# If there is a third unit, add it to the tag. ie. tokens/sec/user
295303
self.tag += f"/{self.third_unit}"
@@ -302,6 +310,7 @@ def _set_tag(self: Self) -> Self:
302310
primary_unit: "MetricUnitT"
303311
time_unit: MetricTimeUnit | MetricTimeUnitInfo
304312
third_unit: "MetricUnitT | None" = None
313+
inverted: bool = False
305314

306315
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
307316
"""Convert a value from this unit to another unit."""
@@ -342,6 +351,24 @@ class MetricOverTimeUnit(BaseMetricUnit):
342351
time_unit=MetricTimeUnit.SECONDS,
343352
third_unit=GenericMetricUnit.USER,
344353
)
354+
IMAGES_PER_SECOND = MetricOverTimeUnitInfo(
355+
primary_unit=GenericMetricUnit.IMAGES,
356+
time_unit=MetricTimeUnit.SECONDS,
357+
)
358+
MS_PER_IMAGE = MetricOverTimeUnitInfo(
359+
time_unit=MetricTimeUnit.MILLISECONDS,
360+
primary_unit=GenericMetricUnit.IMAGE,
361+
inverted=True,
362+
)
363+
VIDEOS_PER_SECOND = MetricOverTimeUnitInfo(
364+
primary_unit=GenericMetricUnit.VIDEOS,
365+
time_unit=MetricTimeUnit.SECONDS,
366+
)
367+
MS_PER_VIDEO = MetricOverTimeUnitInfo(
368+
time_unit=MetricTimeUnit.MILLISECONDS,
369+
primary_unit=GenericMetricUnit.VIDEO,
370+
inverted=True,
371+
)
345372

346373
@cached_property
347374
def info(self) -> MetricOverTimeUnitInfo:
@@ -363,6 +390,11 @@ def third_unit(self) -> "MetricUnitT | None":
363390
"""Get the third unit (if applicable)."""
364391
return self.info.third_unit
365392

393+
@cached_property
def inverted(self) -> bool:
    """Report whether this unit is expressed as time per item.

    The value is read straight from the unit's info object; inverted units
    are tagged as ``time_unit/primary_unit`` (for example ms/image) rather
    than ``primary_unit/time_unit``.
    """
    unit_info = self.info
    return unit_info.inverted
397+
366398

367399
class MetricType(CaseInsensitiveStrEnum):
368400
"""Defines the possible types of metrics."""
@@ -643,6 +675,9 @@ class MetricFlags(Flag):
643675
TOKENIZES_INPUT_ONLY = 1 << 12
644676
"""Metrics that are only applicable when the endpoint tokenizes input text."""
645677

678+
SUPPORTS_VIDEO_ONLY = 1 << 13
679+
"""Metrics that are only applicable to video-based endpoints."""
680+
646681
def has_flags(self, flags: "MetricFlags") -> bool:
647682
"""Return True if the metric has ALL of the given flag(s) (regardless of other flags)."""
648683
# Bitwise AND will return the input flags only if all of the given flags are present.

src/aiperf/common/enums/plugin_enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class EndpointType(CaseInsensitiveStrEnum):
3232
NIM_RANKINGS = "nim_rankings"
3333
SOLIDO_RAG = "solido_rag"
3434
TEMPLATE = "template"
35+
IMAGE_RETRIEVAL = "image_retrieval"
3536

3637

3738
class TransportType(CaseInsensitiveStrEnum):

src/aiperf/common/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
BaseInferenceServerResponse,
7474
BaseResponseData,
7575
EmbeddingResponseData,
76+
ImageRetrievalResponseData,
7677
MetricRecordInfo,
7778
MetricRecordMetadata,
7879
MetricResult,
@@ -149,6 +150,7 @@
149150
"GpuTelemetrySnapshot",
150151
"IOCounters",
151152
"Image",
153+
"ImageRetrievalResponseData",
152154
"InputsFile",
153155
"JsonExportData",
154156
"JsonMetricResult",

src/aiperf/common/models/record_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,18 @@ class RankingsResponseData(BaseResponseData):
625625
)
626626

627627

628+
class ImageRetrievalResponseData(BaseResponseData):
    """Response payload parsed from an image retrieval endpoint."""

    # Required field: one free-form dictionary per entry in the response.
    data: list[dict[str, Any]] = Field(
        default=...,
        description="The image retrieval data from the response.",
    )

    def get_text(self) -> str:
        """Return the textual content of the response.

        Image retrieval responses carry no text, so this is always the
        empty string.
        """
        return ""
638+
639+
628640
class ParsedResponse(AIPerfBaseModel):
629641
"""Parsed response from a inference client."""
630642

@@ -636,6 +648,7 @@ class ParsedResponse(AIPerfBaseModel):
636648
| TextResponseData
637649
| EmbeddingResponseData
638650
| RankingsResponseData
651+
| ImageRetrievalResponseData
639652
| BaseResponseData
640653
| None
641654
] = Field(

src/aiperf/dataset/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@
4646
)
4747
from aiperf.dataset.utils import (
4848
check_file_exists,
49+
encode_audio,
4950
encode_image,
51+
encode_video,
52+
open_audio,
5053
open_image,
54+
open_video,
5155
)
5256

5357
__all__ = [
@@ -84,7 +88,11 @@
8488
"SyntheticDatasetComposer",
8589
"VideoGenerator",
8690
"check_file_exists",
91+
"encode_audio",
8792
"encode_image",
93+
"encode_video",
8894
"main",
95+
"open_audio",
8996
"open_image",
97+
"open_video",
9098
]

src/aiperf/dataset/dataset_manager.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
CommAddress,
1515
CommandType,
1616
ComposerType,
17+
CustomDatasetType,
1718
MessageType,
1819
ServiceType,
1920
)
2021
from aiperf.common.environment import Environment
2122
from aiperf.common.factories import (
2223
ComposerFactory,
2324
DatasetSamplingStrategyFactory,
25+
EndpointFactory,
2426
ServiceFactory,
2527
)
2628
from aiperf.common.hooks import on_command, on_request
@@ -82,11 +84,21 @@ async def _profile_configure_command(
8284
) -> None:
8385
"""Configure the dataset."""
8486

85-
self.info("Configuring tokenizer(s) for dataset manager")
86-
begin = time.perf_counter()
87-
await self._configure_tokenizer()
88-
duration = time.perf_counter() - begin
89-
self.info(lambda: f"Tokenizer(s) configured in {duration:.2f} seconds")
87+
endpoint_meta = EndpointFactory.get_metadata(self.user_config.endpoint.type)
88+
if (
89+
endpoint_meta.tokenizes_input
90+
or self.user_config.input.custom_dataset_type
91+
== CustomDatasetType.MOONCAKE_TRACE
92+
):
93+
self.info("Configuring tokenizer(s) for dataset manager")
94+
begin = time.perf_counter()
95+
await self._configure_tokenizer()
96+
duration = time.perf_counter() - begin
97+
self.info(lambda: f"Tokenizer(s) configured in {duration:.2f} seconds")
98+
else:
99+
self.info(
100+
"Endpoint does not tokenize input, skipping tokenizer configuration"
101+
)
90102

91103
self.info(lambda: f"Configuring dataset for {self.service_id}")
92104
begin = time.perf_counter()

src/aiperf/dataset/generator/prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, config: PromptConfig, tokenizer: Tokenizer, **kwargs):
4949

5050
# TODO: move this under initialize() method
5151
# Initialize corpus if not already done
52-
if self._tokenized_corpus is None:
52+
if self._tokenized_corpus is None and tokenizer:
5353
self._initialize_corpus()
5454

5555
# Initialize prefix prompts pool if the pool size > 0

src/aiperf/dataset/loader/mixins.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from collections.abc import Iterable
5+
from urllib.parse import urlparse
56

7+
from aiperf.common.enums.dataset_enums import AudioFormat
8+
from aiperf.common.enums.media_enums import MediaType
69
from aiperf.common.models import Media
7-
from aiperf.common.types import MediaT
10+
from aiperf.common.types import MediaT, MediaTypeT
11+
from aiperf.dataset import utils
812
from aiperf.dataset.loader.models import CustomDatasetT
913

1014

@@ -51,8 +55,8 @@ def _convert_to_media_objects(
5155
5256
Args:
5357
data: The custom dataset to construct media objects from.
54-
media_class: The target media class (Text, Image, or Audio).
55-
field: The name of the field (e.g., 'text', 'image', 'audio').
58+
media_class: The target media class (Text, Image, Audio, or Video).
59+
field: The name of the field (e.g., 'text', 'image', 'audio', 'video').
5660
name: The name of the media field.
5761
5862
Returns:
@@ -61,6 +65,9 @@ def _convert_to_media_objects(
6165
# Check singular field first
6266
value = getattr(data, field, None)
6367
if value is not None:
68+
# Handle media content (encode local files to base64)
69+
if field in [MediaType.IMAGE, MediaType.VIDEO, MediaType.AUDIO]:
70+
value = self._handle_media_content(value, media_type=MediaType(field))
6471
return [media_class(name=name, contents=[value])]
6572

6673
# Check plural field
@@ -72,4 +79,124 @@ def _convert_to_media_objects(
7279
if all(isinstance(v, media_class) for v in values):
7380
return values
7481

82+
# Handle media content (encode local files to base64)
83+
if field in [MediaType.IMAGE, MediaType.VIDEO, MediaType.AUDIO]:
84+
values = [
85+
self._handle_media_content(v, media_type=MediaType(field))
86+
for v in values
87+
]
88+
7589
return [media_class(name=name, contents=values)]
90+
91+
def _is_url(self, content: str) -> bool:
92+
"""Check if content is a valid URL with scheme and netloc.
93+
94+
Args:
95+
content: The content to check.
96+
97+
Returns:
98+
True if content is a URL, False otherwise.
99+
100+
Raises:
101+
ValueError: If URL has only scheme or only netloc (invalid).
102+
"""
103+
url = urlparse(content)
104+
105+
# Valid URL with both scheme and netloc
106+
if url.scheme and url.netloc:
107+
return True
108+
109+
# Invalid URL - has one but not both
110+
if url.scheme or url.netloc:
111+
raise ValueError(f"Valid URL must have both a scheme and netloc: {content}")
112+
113+
# Not a URL
114+
return False
115+
116+
def _is_already_encoded(self, content: str, media_type: MediaTypeT) -> bool:
    """Tell whether ``content`` already carries inline-encoded media.

    Images and videos count as encoded when the content parses with the
    ``data`` scheme (a data URL). Audio counts as encoded when the content
    has no scheme, contains a comma, and the text before the first comma
    is (case-insensitively) the WAV or MP3 audio format. Any other media
    type is never considered encoded.

    Args:
        content: The content to check.
        media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

    Returns:
        True if content is already encoded, False otherwise.
    """
    parsed = urlparse(content)

    # Images and videos: only data URLs are treated as pre-encoded.
    if media_type in [MediaType.IMAGE, MediaType.VIDEO]:
        return parsed.scheme == "data"

    if media_type != MediaType.AUDIO:
        return False

    # Audio: expect "<format>,<base64>" with no URL scheme in front.
    if parsed.scheme or "," not in content:
        return False

    format_prefix, _, _ = content.partition(",")
    return format_prefix.lower() in [AudioFormat.WAV, AudioFormat.MP3]
143+
144+
def _encode_media_file(self, content: str, media_type: MediaTypeT) -> str:
    """Load a local media file and return its encoded string form.

    Images are opened and encoded through the ``utils`` helpers and
    wrapped in a ``data:image/<format>;base64,...`` URL. Audio and video
    are opened and then encoded by their respective ``utils`` helpers,
    whose return value is passed through unchanged.

    Args:
        content: The file path to encode.
        media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

    Returns:
        The encoded content for the given media type.

    Raises:
        ValueError: If ``media_type`` is not image, audio, or video.

    NOTE(review): missing-file and unsupported-format errors would come
    from the ``utils`` helpers, which are not visible here - confirm the
    exception types they raise.
    """
    if media_type == MediaType.IMAGE:
        image = utils.open_image(content)
        encoded_image = utils.encode_image(image, image.format)
        return f"data:image/{image.format.lower()};base64,{encoded_image}"

    if media_type == MediaType.AUDIO:
        raw_audio, detected_audio_format = utils.open_audio(content)
        return utils.encode_audio(raw_audio, detected_audio_format)

    if media_type == MediaType.VIDEO:
        raw_video, detected_video_format = utils.open_video(content)
        return utils.encode_video(raw_video, detected_video_format)

    raise ValueError(f"Unsupported media type: {media_type}")
172+
173+
def _handle_media_content(self, content: str, media_type: MediaTypeT) -> str:
    """Normalize one media value into something that can be sent as-is.

    Content that is already encoded, or that is a URL, is returned
    untouched. Anything else is treated as a local file path and encoded.

    Args:
        content: The media content - URL, encoded string, or local file path.
        media_type: The media type (MediaType.IMAGE, MediaType.AUDIO, MediaType.VIDEO).

    Returns:
        The original content, or the encoded form of the referenced file.

    Raises:
        ValueError: If the content has only a scheme or only a netloc
            (raised by the URL check), or if encoding rejects the media type.
    """
    # The encoded check must run before the URL check: a data URL has a
    # scheme but no netloc, which the URL check would reject as invalid.
    passthrough = self._is_already_encoded(content, media_type) or self._is_url(
        content
    )
    if passthrough:
        return content

    # Neither encoded nor a URL: treat it as a path on disk.
    return self._encode_media_file(content, media_type)

0 commit comments

Comments
 (0)