Skip to content

Commit a75dc68

Browse files
committed
Change the default dataset repo_id to the new name of the public dataset
1 parent d9c0bcc commit a75dc68

File tree

1 file changed

+7
-13
lines changed
  • multimodal/vl2l/src/mlperf_inference_multimodal_vl2l

1 file changed

+7
-13
lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/schema.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,7 @@ class TestSettings(BaseModelWithAttributeDescriptionsFromDocstrings):
214214
mode="before",
215215
)
216216
@classmethod
217-
def parse_timedelta(cls, value: timedelta | float |
218-
str) -> timedelta | str:
217+
def parse_timedelta(cls, value: timedelta | float | str) -> timedelta | str:
219218
"""Parse timedelta from seconds (int/float/str) or ISO 8601 format."""
220219
if isinstance(value, timedelta):
221220
return value
@@ -241,12 +240,9 @@ def to_lgtype(self) -> lg.TestSettings:
241240
settings.server_target_latency_ns = round(
242241
self.server_target_latency.total_seconds() * 1e9,
243242
)
244-
settings.ttft_latency = round(
245-
self.server_ttft_latency.total_seconds() * 1e9)
246-
settings.tpot_latency = round(
247-
self.server_tpot_latency.total_seconds() * 1e9)
248-
settings.min_duration_ms = round(
249-
self.min_duration.total_seconds() * 1000)
243+
settings.ttft_latency = round(self.server_ttft_latency.total_seconds() * 1e9)
244+
settings.tpot_latency = round(self.server_tpot_latency.total_seconds() * 1e9)
245+
settings.min_duration_ms = round(self.min_duration.total_seconds() * 1000)
250246
settings.min_query_count = self.min_query_count
251247
settings.performance_sample_count_override = (
252248
self.performance_sample_count_override
@@ -343,7 +339,7 @@ class Model(BaseModelWithAttributeDescriptionsFromDocstrings):
343339
class Dataset(BaseModelWithAttributeDescriptionsFromDocstrings):
344340
"""Specifies a dataset on HuggingFace."""
345341

346-
repo_id: str = "Shopify/the-catalogue-public-beta"
342+
repo_id: str = "Shopify/product-catalogue"
347343
"""The HuggingFace repository ID of the dataset."""
348344

349345
token: str | None = None
@@ -454,8 +450,7 @@ def __init__(self, flag: str) -> None:
454450
class BlacklistedVllmCliFlagError(ValueError):
455451
"""The exception raised when a blacklisted vllm CLI flag is encountered."""
456452

457-
BLACKLIST: ClassVar[list[str]] = [
458-
"--model", "--host", "--port", "--api-key"]
453+
BLACKLIST: ClassVar[list[str]] = ["--model", "--host", "--port", "--api-key"]
459454

460455
def __init__(self, flag: str) -> None:
461456
"""Initialize the exception."""
@@ -508,6 +503,5 @@ def ensure_content_is_list(
508503
== "pydantic_core._pydantic_core"
509504
and message["content"].__class__.__name__ == "ValidatorIterator"
510505
):
511-
message["content"] = list(
512-
message["content"]) # type: ignore[arg-type]
506+
message["content"] = list(message["content"]) # type: ignore[arg-type]
513507
return messages

0 commit comments

Comments
 (0)