Skip to content

Commit 3583a74

Browse files
committed
feat: add nim image retrieval endpoint support
1 parent f6f4e34 commit 3583a74

33 files changed

+1340
-276
lines changed

src/aiperf/common/enums/metric_enums.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,10 @@ class GenericMetricUnit(BaseMetricUnit):
188188
RATIO = _unit("ratio")
189189
USER = _unit("user")
190190
PERCENT = _unit("%")
191+
IMAGE = _unit("image")
192+
IMAGES = _unit("images")
193+
VIDEO = _unit("video")
194+
VIDEOS = _unit("videos")
191195

192196

193197
class PowerMetricUnitInfo(BaseMetricUnitInfo):
@@ -289,7 +293,11 @@ class MetricOverTimeUnitInfo(BaseMetricUnitInfo):
289293
@model_validator(mode="after")
290294
def _set_tag(self: Self) -> Self:
291295
"""Set the tag based on the existing units. ie. requests/sec, tokens/sec, etc."""
292-
self.tag = f"{self.primary_unit}/{self.time_unit}"
296+
self.tag = (
297+
f"{self.primary_unit}/{self.time_unit}"
298+
if not self.inverted
299+
else f"{self.time_unit}/{self.primary_unit}"
300+
)
293301
if self.third_unit:
294302
# If there is a third unit, add it to the tag. ie. tokens/sec/user
295303
self.tag += f"/{self.third_unit}"
@@ -302,6 +310,7 @@ def _set_tag(self: Self) -> Self:
302310
primary_unit: "MetricUnitT"
303311
time_unit: MetricTimeUnit | MetricTimeUnitInfo
304312
third_unit: "MetricUnitT | None" = None
313+
inverted: bool = False
305314

306315
def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float:
307316
"""Convert a value from this unit to another unit."""
@@ -342,6 +351,24 @@ class MetricOverTimeUnit(BaseMetricUnit):
342351
time_unit=MetricTimeUnit.SECONDS,
343352
third_unit=GenericMetricUnit.USER,
344353
)
354+
IMAGES_PER_SECOND = MetricOverTimeUnitInfo(
355+
primary_unit=GenericMetricUnit.IMAGES,
356+
time_unit=MetricTimeUnit.SECONDS,
357+
)
358+
MS_PER_IMAGE = MetricOverTimeUnitInfo(
359+
time_unit=MetricTimeUnit.MILLISECONDS,
360+
primary_unit=GenericMetricUnit.IMAGE,
361+
inverted=True,
362+
)
363+
VIDEOS_PER_SECOND = MetricOverTimeUnitInfo(
364+
primary_unit=GenericMetricUnit.VIDEOS,
365+
time_unit=MetricTimeUnit.SECONDS,
366+
)
367+
MS_PER_VIDEO = MetricOverTimeUnitInfo(
368+
time_unit=MetricTimeUnit.MILLISECONDS,
369+
primary_unit=GenericMetricUnit.VIDEO,
370+
inverted=True,
371+
)
345372

346373
@cached_property
347374
def info(self) -> MetricOverTimeUnitInfo:
@@ -363,6 +390,11 @@ def third_unit(self) -> "MetricUnitT | None":
363390
"""Get the third unit (if applicable)."""
364391
return self.info.third_unit
365392

393+
@cached_property
394+
def inverted(self) -> bool:
395+
"""Whether the metric is inverted (e.g. time / metric)."""
396+
return self.info.inverted
397+
366398

367399
class MetricType(CaseInsensitiveStrEnum):
368400
"""Defines the possible types of metrics."""
@@ -643,6 +675,9 @@ class MetricFlags(Flag):
643675
TOKENIZES_INPUT_ONLY = 1 << 12
644676
"""Metrics that are only applicable when the endpoint tokenizes input text."""
645677

678+
SUPPORTS_VIDEO_ONLY = 1 << 13
679+
"""Metrics that are only applicable to video-based endpoints."""
680+
646681
def has_flags(self, flags: "MetricFlags") -> bool:
647682
"""Return True if the metric has ALL of the given flag(s) (regardless of other flags)."""
648683
# Bitwise AND will return the input flags only if all of the given flags are present.

src/aiperf/common/enums/plugin_enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class EndpointType(CaseInsensitiveStrEnum):
3232
NIM_RANKINGS = "nim_rankings"
3333
SOLIDO_RAG = "solido_rag"
3434
TEMPLATE = "template"
35+
IMAGE_RETRIEVAL = "image_retrieval"
3536

3637

3738
class TransportType(CaseInsensitiveStrEnum):

src/aiperf/common/messages/base_messages.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,28 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
import json
43
import time
54
from typing import ClassVar
65

76
from pydantic import Field
87

98
from aiperf.common.aiperf_logger import AIPerfLogger
109
from aiperf.common.enums.message_enums import MessageType
11-
from aiperf.common.models.base_models import AIPerfBaseModel, exclude_if_none
10+
from aiperf.common.models.base_models import AIPerfBaseModel
1211
from aiperf.common.models.error_models import ErrorDetails
1312
from aiperf.common.types import MessageTypeT
13+
from aiperf.common.utils import load_json_str
1414

1515
_logger = AIPerfLogger(__name__)
1616

1717

18-
@exclude_if_none("request_ns", "request_id")
1918
class Message(AIPerfBaseModel):
20-
"""Base message class for optimized message handling. Based on the AIPerfBaseModel class,
21-
so it supports @exclude_if_none decorator. see :class:`AIPerfBaseModel` for more details.
19+
"""Base message class for optimized message handling.
2220
2321
This class provides a base for all messages, including common fields like message_type,
24-
request_ns, and request_id. It also supports optional field exclusion based on the
25-
@exclude_if_none decorator.
22+
request_ns, and request_id.
2623
2724
Each message model should inherit from this class, set the message_type field,
2825
and define its own additional fields.
29-
30-
Example:
31-
```python
32-
@exclude_if_none("some_field")
33-
class ExampleMessage(Message):
34-
some_field: int | None = Field(default=None)
35-
other_field: int = Field(default=1)
36-
```
3726
"""
3827

3928
_message_type_lookup: ClassVar[dict[MessageTypeT, type["Message"]]] = {}
@@ -71,7 +60,7 @@ def __get_validators__(cls):
7160
def from_json(cls, json_str: str | bytes | bytearray) -> "Message":
7261
"""Deserialize a message from a JSON string, attempting to auto-detect the message type.
7362
NOTE: If you already know the message type, use the more performant :meth:`from_json_with_type` instead."""
74-
data = json.loads(json_str)
63+
data = load_json_str(json_str)
7564
message_type = data.get("message_type")
7665
if not message_type:
7766
raise ValueError(f"Missing message_type: {json_str}")
@@ -97,7 +86,7 @@ def from_json_with_type(
9786
return message_class.model_validate_json(json_str)
9887

9988
def __str__(self) -> str:
100-
return self.model_dump_json()
89+
return self.model_dump_json(exclude_none=True)
10190

10291

10392
class RequiresRequestNSMixin(Message):

src/aiperf/common/messages/command_messages.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
import json
43
import uuid
54
from typing import Any, ClassVar
65

@@ -18,14 +17,13 @@
1817
from aiperf.common.models import (
1918
ErrorDetails,
2019
ProcessRecordsResult,
21-
exclude_if_none,
2220
)
2321
from aiperf.common.types import CommandTypeT, MessageTypeT, ServiceTypeT
22+
from aiperf.common.utils import load_json_str
2423

2524
_logger = AIPerfLogger(__name__)
2625

2726

28-
@exclude_if_none("target_service_id", "target_service_type")
2927
class TargetedServiceMessage(BaseServiceMessage):
3028
"""Message that can be targeted to a specific service by id or type.
3129
If both `target_service_type` and `target_service_id` are None, the message is
@@ -80,7 +78,7 @@ def __init_subclass__(cls, **kwargs):
8078
@classmethod
8179
def from_json(cls, json_str: str | bytes | bytearray) -> "CommandMessage":
8280
"""Deserialize a command message from a JSON string, attempting to auto-detect the command type."""
83-
data = json.loads(json_str)
81+
data = load_json_str(json_str)
8482
command_type = data.get("command")
8583
if not command_type:
8684
raise ValueError(f"Missing command: {json_str}")
@@ -139,7 +137,7 @@ def __init_subclass__(cls, **kwargs):
139137
@classmethod
140138
def from_json(cls, json_str: str | bytes | bytearray) -> "CommandResponse":
141139
"""Deserialize a command response message from a JSON string, attempting to auto-detect the command response type."""
142-
data = json.loads(json_str)
140+
data = load_json_str(json_str)
143141
status = data.get("status")
144142
if not status:
145143
raise ValueError(f"Missing command response status: {json_str}")
@@ -259,7 +257,6 @@ class SpawnWorkersCommand(CommandMessage):
259257
num_workers: int = Field(..., description="Number of workers to spawn")
260258

261259

262-
@exclude_if_none("worker_ids", "num_workers")
263260
class ShutdownWorkersCommand(CommandMessage):
264261
command: CommandTypeT = CommandType.SHUTDOWN_WORKERS
265262

src/aiperf/common/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
########################################################################
1111
from aiperf.common.models.base_models import (
1212
AIPerfBaseModel,
13-
exclude_if_none,
1413
)
1514
from aiperf.common.models.credit_models import (
1615
CreditPhaseConfig,
@@ -69,6 +68,7 @@
6968
BaseInferenceServerResponse,
7069
BaseResponseData,
7170
EmbeddingResponseData,
71+
ImageRetrievalResponseData,
7272
MetricRecordInfo,
7373
MetricRecordMetadata,
7474
MetricResult,
@@ -144,6 +144,7 @@
144144
"GpuTelemetrySnapshot",
145145
"IOCounters",
146146
"Image",
147+
"ImageRetrievalResponseData",
147148
"InputsFile",
148149
"JsonExportData",
149150
"JsonMetricResult",
@@ -194,6 +195,5 @@
194195
"WorkerTaskStats",
195196
"create_balanced_distribution",
196197
"create_uniform_distribution",
197-
"exclude_if_none",
198198
"logger",
199199
]
Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,14 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
from typing import Any, ClassVar
4-
5-
from pydantic import BaseModel, ConfigDict, model_serializer
6-
7-
from aiperf.common.types import AIPerfBaseModelT
8-
9-
10-
def exclude_if_none(*field_names: str):
11-
"""Decorator to set the _exclude_if_none_fields class attribute to the set of
12-
field names that should be excluded if they are None.
13-
"""
14-
15-
def decorator(model: type[AIPerfBaseModelT]) -> type[AIPerfBaseModelT]:
16-
# This attribute is defined by the AIPerfBaseModel class.
17-
if not hasattr(model, "_exclude_if_none_fields"):
18-
model._exclude_if_none_fields = set()
19-
model._exclude_if_none_fields.update(set(field_names))
20-
return model
21-
22-
return decorator
3+
from pydantic import BaseModel, ConfigDict
234

245

256
class AIPerfBaseModel(BaseModel):
26-
"""Base model for all AIPerf Pydantic models. This class is configured to allow
27-
arbitrary types to be used as fields as to allow for more flexible model definitions
28-
by end users without breaking the existing code.
29-
30-
The @exclude_if_none decorator can also be used to specify which fields
31-
should be excluded from the serialized model if they are None. This is a workaround
32-
for the fact that pydantic does not support specifying exclude_none on a per-field basis.
33-
"""
7+
"""Base model for all AIPerf Pydantic models.
348
35-
_exclude_if_none_fields: ClassVar[set[str]] = set()
36-
"""Set of field names that should be excluded from the serialized model if they
37-
are None. This is set by the @exclude_if_none decorator.
9+
This class is configured to allow arbitrary types to be used as fields
10+
to allow for more flexible model definitions by end users without breaking
11+
existing code.
3812
"""
3913

40-
# Allow extras by default to be more flexible for end users
4114
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
42-
43-
@model_serializer
44-
def _serialize_model(self) -> dict[str, Any]:
45-
"""Serialize the model to a dictionary.
46-
47-
This method overrides the default serializer to exclude fields that with a
48-
value of None and were marked with the @exclude_if_none decorator.
49-
"""
50-
return {
51-
k: v
52-
for k, v in self
53-
if not (k in self._exclude_if_none_fields and v is None)
54-
}

src/aiperf/common/models/dataset_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic import Field
77

88
from aiperf.common.enums import MediaType
9-
from aiperf.common.models.base_models import AIPerfBaseModel, exclude_if_none
9+
from aiperf.common.models.base_models import AIPerfBaseModel
1010
from aiperf.common.types import MediaTypeT
1111

1212

@@ -45,7 +45,6 @@ class Video(Media):
4545
media_type: ClassVar[MediaTypeT] = MediaType.VIDEO
4646

4747

48-
@exclude_if_none("role")
4948
class Turn(AIPerfBaseModel):
5049
"""A dataset representation of a single turn within a conversation.
5150

src/aiperf/common/models/export_models.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33

44
from datetime import datetime
55

6-
from pydantic import ConfigDict, Field
6+
from pydantic import BaseModel, ConfigDict, Field
77

88
from aiperf.common.config import UserConfig
99
from aiperf.common.models import ErrorDetailsCount
10-
from aiperf.common.models.base_models import AIPerfBaseModel, exclude_if_none
10+
from aiperf.common.models.base_models import AIPerfBaseModel
1111

1212

13-
@exclude_if_none(
14-
"min", "max", "p1", "p5", "p10", "p25", "p50", "p75", "p90", "p95", "p99", "std"
15-
)
16-
class JsonMetricResult(AIPerfBaseModel):
13+
class JsonMetricResult(BaseModel):
1714
"""The result values of a single metric for JSON export.
1815
1916
NOTE:
@@ -70,7 +67,7 @@ class TelemetryExportData(AIPerfBaseModel):
7067
endpoints: dict[str, EndpointData]
7168

7269

73-
class JsonExportData(AIPerfBaseModel):
70+
class JsonExportData(BaseModel):
7471
"""Summary data to be exported to a JSON file.
7572
7673
NOTE:

src/aiperf/common/models/record_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,18 @@ class RankingsResponseData(BaseResponseData):
622622
)
623623

624624

625+
class ImageRetrievalResponseData(BaseResponseData):
626+
"""Parsed image retrieval response data."""
627+
628+
data: list[dict[str, Any]] = Field(
629+
..., description="The image retrieval data from the response."
630+
)
631+
632+
def get_text(self) -> str:
633+
"""Get the text of the response (empty for image retrieval)."""
634+
return ""
635+
636+
625637
class ParsedResponse(AIPerfBaseModel):
626638
"""Parsed response from a inference client."""
627639

@@ -633,6 +645,7 @@ class ParsedResponse(AIPerfBaseModel):
633645
| TextResponseData
634646
| EmbeddingResponseData
635647
| RankingsResponseData
648+
| ImageRetrievalResponseData
636649
| BaseResponseData
637650
| None
638651
] = Field(

src/aiperf/common/models/sequence_distribution.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@
3232

3333
from __future__ import annotations
3434

35-
import json
3635
import logging
3736
import re
3837
from dataclasses import dataclass
3938

4039
import numpy as np
40+
import orjson
41+
42+
from aiperf.common.utils import load_json_str
4143

4244
logger = logging.getLogger(__name__)
4345

@@ -337,9 +339,9 @@ def parse(cls, dist_str: str) -> SequenceLengthDistribution:
337339
def _parse_json_format(cls, json_str: str) -> SequenceLengthDistribution:
338340
"""Parse JSON format: {"pairs": [{"isl": 256, "isl_stddev": 10, "osl": 128, "osl_stddev": 5, "prob": 40}, ...]}"""
339341
try:
340-
data = json.loads(json_str)
341-
except json.JSONDecodeError as e:
342-
raise ValueError(f"Invalid JSON format: {e}") from None
342+
data = load_json_str(json_str)
343+
except orjson.JSONDecodeError as e:
344+
raise ValueError(f"Invalid JSON format: {e}") from e
343345

344346
# Validate structure outside the JSON parsing try-catch
345347
if "pairs" not in data:

0 commit comments

Comments
 (0)