Commit b9dd5ad

[Automated Commit] Format Codebase
1 parent: ee91e7f

File tree: 4 files changed (+22, −10 lines)

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/cli.py

Lines changed: 3 additions & 1 deletion
Lines 70–76:

@@ -70,7 +70,9 @@ def benchmark(
     )
     logger.info("Running VL2L benchmark with settings: {}", settings)
     logger.info("Running VL2L benchmark with dataset: {}", dataset)
-    logger.info("Running VL2L benchmark with OpenAI API endpoint: {}", endpoint)
+    logger.info(
+        "Running VL2L benchmark with OpenAI API endpoint: {}",
+        endpoint)
     logger.info("Running VL2L benchmark with random seed: {}", random_seed)
     test_settings, log_settings = settings.to_lgtype()
     task = ShopifyGlobalCatalogue(
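
The reflow is behavior-preserving: Loguru fills "{}" placeholders from positional arguments with str.format semantics, so the multi-line call logs exactly the same message as the one-liner it replaces. A minimal sketch, with a hypothetical endpoint value:

from loguru import logger

endpoint = "http://localhost:8000/v1"  # hypothetical value, for illustration only

# Loguru interpolates the positional argument into "{}" (str.format semantics),
# so splitting the call across lines changes nothing at runtime.
logger.info(
    "Running VL2L benchmark with OpenAI API endpoint: {}",
    endpoint)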

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/evaluation.py

Lines changed: 2 additions & 1 deletion
Lines 51–58:

@@ -51,7 +51,8 @@ def get_hierarchical_components(
     intersection_count = 0

     # Iterate through the paths simultaneously
-    for pred_cat, true_cat in zip(predicted_categories, true_categories, strict=False):
+    for pred_cat, true_cat in zip(
+            predicted_categories, true_categories, strict=False):
         if pred_cat == true_cat:
             intersection_count += 1
         else:
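
Only part of get_hierarchical_components is visible in this hunk, but the reflowed loop is the common-prefix count at the heart of hierarchical precision/recall. A sketch of that metric under assumptions — the function name, the break on divergence, and the F1 step are illustrative, not the module's actual code:

def hierarchical_f1(predicted: list[str], true: list[str]) -> float:
    """Illustrative only: F1 over the shared prefix of two category paths."""
    intersection_count = 0
    # Iterate through the paths simultaneously
    for pred_cat, true_cat in zip(predicted, true, strict=False):
        if pred_cat == true_cat:
            intersection_count += 1
        else:
            break  # paths diverge; deeper levels can no longer match
    if intersection_count == 0:
        return 0.0
    precision = intersection_count / len(predicted)
    recall = intersection_count / len(true)
    return 2 * precision * recall / (precision + recall)

print(hierarchical_f1(["Home", "Kitchen", "Mugs"],
                      ["Home", "Kitchen", "Glassware"]))  # ~0.667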

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/schema.py

Lines changed: 10 additions & 5 deletions
Lines 196–203:

@@ -196,7 +196,8 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
         mode="before",
     )
     @classmethod
-    def parse_timedelta(cls, value: timedelta | float | str) -> timedelta | str:
+    def parse_timedelta(cls, value: timedelta | float |
+                        str) -> timedelta | str:
         """Parse timedelta from seconds (int/float/str) or ISO 8601 format."""
         if isinstance(value, timedelta):
             return value
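
The validator being rewrapped runs with mode="before", i.e. ahead of pydantic's own coercion: numeric inputs become timedelta immediately, while strings that are not plain numbers fall through for pydantic to parse as ISO 8601 durations. A self-contained sketch of the pattern — the model and field names are placeholders, not the project's:

from datetime import timedelta

from pydantic import BaseModel, field_validator


class LatencyConfig(BaseModel):  # placeholder model, not the project's TestSettings
    server_ttft_latency: timedelta

    @field_validator("server_ttft_latency", mode="before")
    @classmethod
    def parse_timedelta(cls, value: timedelta | float | str) -> timedelta | str:
        """Parse timedelta from seconds (int/float/str) or ISO 8601 format."""
        if isinstance(value, timedelta):
            return value
        try:
            return timedelta(seconds=float(value))
        except (TypeError, ValueError):
            return value  # e.g. "PT0.5S"; pydantic handles ISO 8601 durations


print(LatencyConfig(server_ttft_latency="0.5").server_ttft_latency)     # 0:00:00.500000
print(LatencyConfig(server_ttft_latency="PT0.5S").server_ttft_latency)  # 0:00:00.500000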
Lines 222–234:

@@ -222,9 +223,12 @@ def to_lgtype(self) -> lg.TestSettings:
         settings.server_target_latency_ns = round(
             self.server_target_latency.total_seconds() * 1e9,
         )
-        settings.ttft_latency = round(self.server_ttft_latency.total_seconds() * 1e9)
-        settings.tpot_latency = round(self.server_tpot_latency.total_seconds() * 1e9)
-        settings.min_duration_ms = round(self.min_duration.total_seconds() * 1000)
+        settings.ttft_latency = round(
+            self.server_ttft_latency.total_seconds() * 1e9)
+        settings.tpot_latency = round(
+            self.server_tpot_latency.total_seconds() * 1e9)
+        settings.min_duration_ms = round(
+            self.min_duration.total_seconds() * 1000)
         settings.min_query_count = self.min_query_count
         settings.performance_sample_count_override = (
             self.performance_sample_count_override
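
The reflow leaves the arithmetic untouched: judging by the field names and scale factors, LoadGen takes the latency targets as integer nanoseconds and min_duration_ms as integer milliseconds, so each timedelta is scaled and rounded. A worked example with assumed values of a 100 ms TTFT target and a 2 s minimum duration:

from datetime import timedelta

server_ttft_latency = timedelta(milliseconds=100)  # assumed value
min_duration = timedelta(seconds=2)                # assumed value

print(round(server_ttft_latency.total_seconds() * 1e9))  # 100000000 ns
print(round(min_duration.total_seconds() * 1000))         # 2000 ms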
Lines 414–423:

@@ -414,5 +418,6 @@ def ensure_content_is_list(
             == "pydantic_core._pydantic_core"
             and message["content"].__class__.__name__ == "ValidatorIterator"
         ):
-            message["content"] = list(message["content"])  # type: ignore[arg-type]
+            message["content"] = list(
+                message["content"])  # type: ignore[arg-type]
     return messages
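
Context for the wrapped call: pydantic's ValidatorIterator validates sequence items lazily as they are consumed, so the code materializes it into a concrete list before the message content is reused. A hedged sketch of the detection idiom, assuming a dict-shaped chat message:

from typing import Any


def ensure_list(message: dict[str, Any]) -> dict[str, Any]:
    """Illustrative only: force pydantic's lazy ValidatorIterator into a list."""
    content = message["content"]
    if (type(content).__module__ == "pydantic_core._pydantic_core"
            and type(content).__name__ == "ValidatorIterator"):
        message["content"] = list(content)  # consume the lazy iterator exactly once
    return message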

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/task.py

Lines changed: 7 additions & 3 deletions
Lines 145–153:

@@ -145,7 +145,9 @@ def estimated_num_performance_samples(self) -> int:
         """
         estimation_indices = random.sample(
             range(self.total_num_samples),
-            k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, self.total_num_samples),
+            k=min(
+                MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES,
+                self.total_num_samples),
         )
         estimation_samples = [
             self.formulate_loaded_sample(self.dataset[i]) for i in estimation_indices
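
The surrounding property evidently estimates performance-sample capacity from a small random subset of the dataset; the reformatted k=min(...) simply caps the draw at the population size, which random.sample requires for sampling without replacement. A minimal sketch with hypothetical values:

import random

MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES = 512  # assumed cap, for illustration
total_num_samples = 100                       # hypothetical dataset size

# Draw without replacement; k may never exceed the population size.
estimation_indices = random.sample(
    range(total_num_samples),
    k=min(MAX_NUM_ESTIMATION_PERFORMANCE_SAMPLES, total_num_samples),
)
print(len(estimation_indices))  # 100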
Lines 205–214:

@@ -205,7 +207,8 @@ def _unload_samples_from_ram(query_sample_indices: list[int]) -> None:
             _unload_samples_from_ram,
         )

-    async def _query_endpoint_async_batch(self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_batch(
+            self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         try:
             sample = self.loaded_samples[query_sample.index]
Lines 282–292:

@@ -282,7 +285,8 @@ async def _query_endpoint_async_batch(self, query_sample: lg.QuerySample) -> Non
             ],
         )

-    async def _query_endpoint_async_stream(self, query_sample: lg.QuerySample) -> None:
+    async def _query_endpoint_async_stream(
+            self, query_sample: lg.QuerySample) -> None:
         """Query the endpoint through the async OpenAI API client."""
         ttft_set = False
         try:
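
Both reflowed methods drive an OpenAI-compatible endpoint through the async client; the streaming variant also tracks time-to-first-token (hence the ttft_set flag initialized just below). A self-contained sketch of that measurement with the official openai client — model and prompt are placeholders, and the benchmark's error handling and LoadGen reporting are elided:

import time

from openai import AsyncOpenAI


async def measure_ttft(client: AsyncOpenAI, model: str, prompt: str) -> float | None:
    """Illustrative only: time-to-first-token over a streaming chat completion."""
    start = time.perf_counter()
    ttft = None
    stream = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    async for chunk in stream:
        if ttft is None and chunk.choices and chunk.choices[0].delta.content:
            ttft = time.perf_counter() - start  # first content token arrived
    return ttft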
