Skip to content

Commit 0dc13dc

Browse files
committed
additional fixes to reviews
1 parent 46e4fc0 commit 0dc13dc

File tree

3 files changed

+103
-56
lines changed

3 files changed

+103
-56
lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/cli.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .task import ShopifyGlobalCatalogue
1919

2020
app = Typer()
21-
21+
SplitType = Literal["train", "test"]
2222

2323
class TestScenario(StrEnum):
2424
"""The test scenario for the MLPerf inference LoadGen."""
@@ -378,9 +378,33 @@ class Dataset(BaseModel):
378378
] = None
379379

380380
split: Annotated[
381-
Literal["train", "test"],
382-
Field(description="choose between train or test split"),
383-
] = "train"
381+
list[str],
382+
Field(
383+
description=(
384+
"List of splits in order (e.g. ['train', 'test']). "
385+
"Allowed values: 'train', 'test'."
386+
),
387+
),
388+
] = ["train"]
389+
390+
@field_validator("split", mode="before")
391+
@classmethod
392+
def normalize_and_validate_split(cls, v: str) -> list[str]:
393+
"""Normalize and validate the input string of field split."""
394+
# Allow a single string like "train" or "train,test"
395+
if isinstance(v, str):
396+
v = [part.strip() for part in v.split(",") if part.strip()]
397+
398+
if not isinstance(v, list):
399+
err="split must be a string or a list of strings"
400+
raise TypeError(err)
401+
402+
allowed = {"train", "test"}
403+
for item in v:
404+
if item not in allowed:
405+
msg_err = f"Invalid split {item!r}. Must be one of: {sorted(allowed)}"
406+
raise ValueError(msg_err)
407+
return v
384408

385409

386410
class Verbosity(StrEnum):

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/evaluation.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Task definitions for the VL2L benchmark."""
22

3+
from __future__ import annotations
4+
35
import json
46
from pathlib import Path
57
from typing import TYPE_CHECKING
@@ -8,11 +10,12 @@
810
from datasets import load_dataset
911
from hiclass.metrics import f1
1012
from loguru import logger
11-
from pydantic import FilePath
1213
from sklearn.metrics import f1_score
1314
from tabulate import tabulate
1415

1516
if TYPE_CHECKING:
17+
from pydantic import FilePath
18+
1619
from .cli import Dataset as DatasetCLI
1720

1821

@@ -58,8 +61,8 @@ def get_hierarchical_components(predicted_path: str,
5861
return intersection_count, pred_length, true_length
5962

6063

61-
def calculate_hierarchical_metrics(data: list[tuple[str, str]]) -> float:
62-
"""Calculates the aggregate hP, hR, and hF scores for a list of samples.
64+
def calculate_hierarchical_f1(data: list[tuple[str, str]]) -> float:
65+
"""Calculates the aggregate hF scores for a list of samples.
6366
6467
Args:
6568
data: A list of tuples, where each tuple is
@@ -107,9 +110,30 @@ def calculate_exact_match(generated_text: str, original_text: str) -> float:
107110

108111
return 1.0 if gen == orig else 0.0
109112

113+
def calculate_secondhand_f1(data: list[tuple[str, str]]) -> float:
114+
"""Calculate F1 score of is_secondhand field.
115+
116+
Args:
117+
data: List of tuples of predicted and true values
118+
Returns:
119+
f1 score
120+
"""
121+
y_pred = []
122+
y_src = []
123+
for pred, src in data:
124+
y_pred.append(pred)
125+
y_src.append(src)
126+
127+
return f1_score(y_src, y_pred)
128+
129+
def calculate_hiclass_f1(data: list[tuple[str, str]]) -> float:
130+
"""Alt method to calculate hierarchical F1.
110131
111-
def alt_f1_score(data: list[tuple[str, str]]) -> float:
112-
"""Alt method to calculate hierarchical F1."""
132+
Args:
133+
data: List of tuples of predicted and true values
134+
Returns:
135+
f1 score
136+
"""
113137
y_pred_raw = []
114138
y_true_raw = []
115139

@@ -142,19 +166,19 @@ def alt_f1_score(data: list[tuple[str, str]]) -> float:
142166
return f1(y_true, y_pred)
143167

144168

145-
def run_evaluation(filename: FilePath, dataset: "DatasetCLI") -> None:
169+
def run_evaluation(filename: FilePath, dataset: DatasetCLI) -> None:
146170
"""Main function to run the evaluation."""
147171
with Path.open(filename) as f:
148172
model_output = json.load(f)
149173

150174
original_data = load_dataset(
151175
dataset.repo_id,
152176
dataset.token,
153-
)[dataset.split]
177+
split="+".join(dataset.split),
178+
)
154179

155180
category_dataset_pred_src = []
156-
is_secondhand_pred = []
157-
is_secondhand_src = []
181+
is_secondhand_pred_src = []
158182
for elem in model_output:
159183
byte_data = bytes.fromhex(elem["data"])
160184
idx = elem["qsl_idx"]
@@ -163,15 +187,13 @@ def run_evaluation(filename: FilePath, dataset: "DatasetCLI") -> None:
163187
ground_truth_item = original_data[idx]
164188
category_dataset_pred_src.append((pred_item["category"],
165189
ground_truth_item["ground_truth_category"]))
166-
is_secondhand_pred.append(int(pred_item["is_secondhand"]))
167-
is_secondhand_src.append(
168-
int(ground_truth_item["ground_truth_is_secondhand"]))
190+
is_secondhand_pred_src.append((int(pred_item["is_secondhand"]),
191+
int(ground_truth_item["ground_truth_is_secondhand"])))
169192

170-
category_f1_score = calculate_hierarchical_metrics(
193+
category_f1_score = calculate_hierarchical_f1(
171194
category_dataset_pred_src)
172-
hiclass_f1 = alt_f1_score(category_dataset_pred_src)
173-
is_secondhand_f1_score = f1_score(is_secondhand_src,
174-
is_secondhand_pred)
195+
hiclass_f1 = calculate_hiclass_f1(category_dataset_pred_src)
196+
is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
175197

176198
data = [
177199
["category", category_f1_score, hiclass_f1],

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/task.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,17 @@ def __init__(
4646
dataset_cli: The dataset configuration passed in from the CLI.
4747
model_cli: The model configuration passed in from the CLI.
4848
endpoint_cli: The endpoint configuration passed in from the CLI.
49-
settings: Parameters of the current benchmark
49+
settings: Parameters of the current benchmark.
5050
random_seed: The random seed to use for the task.
5151
"""
5252
random.seed(random_seed)
5353
self.scenario = settings.scenario
5454
self.dataset = load_dataset(
5555
dataset_cli.repo_id,
5656
token=dataset_cli.token,
57+
split="+".join(dataset_cli.split),
5758
)
59+
logger.info(f"LEN: {len(self.dataset)}")
5860
self.model_cli = model_cli
5961
self.openai_api_client = AsyncOpenAI(
6062
base_url=endpoint_cli.url,
@@ -377,7 +379,7 @@ def __init__(
377379
dataset_cli: The dataset configuration passed in from the CLI.
378380
model_cli: The model configuration passed in from the CLI.
379381
endpoint_cli: The endpoint configuration passed in from the CLI.
380-
settings: Parameters of the current benchmark
382+
settings: Parameters of the current benchmark.
381383
random_seed: The random seed to use for the task.
382384
"""
383385
super().__init__(
@@ -387,8 +389,6 @@ def __init__(
387389
settings=settings,
388390
random_seed=random_seed,
389391
)
390-
# Shopify only released the train split so far.
391-
self.dataset = self.dataset[dataset_cli.split]
392392

393393
@staticmethod
394394
def formulate_messages(
@@ -409,44 +409,45 @@ def formulate_messages(
409409
return [
410410
{
411411
"role": "system",
412-
"content": """
413-
Please analyze the product from the user prompt
414-
and provide the following fields in a valid JSON object:
415-
- category
416-
- brand
417-
- is_secondhand
418-
419-
You must choose only one, which is the most appropriate/correct,
420-
category out of the list of possible product categories.
421-
422-
Your response should only contain a valid JSON object and nothing more.
423-
The JSON object should match the followng JSON schema:
424-
```json
425-
{
426-
"type": "object",
427-
"properties": {
428-
"category": {"type": "string"},
429-
"brand": {"type": "string"},
430-
"is_secondhand": {"type": "boolean"}
431-
}
432-
}
433-
```
434-
""",
412+
"content": """Please analyze the product from the user prompt
413+
and provide the following fields in a valid JSON object:
414+
- category
415+
- brand
416+
- is_secondhand
417+
You must choose only one, which is the most appropriate/correct,
418+
category out of the list of possible product categories.
419+
Your response should only contain a valid JSON object and nothing more.
420+
The JSON object should match the following JSON schema:
421+
```json
422+
{
423+
"type": "object",
424+
"properties": {
425+
"category": {"type": "string"},
426+
"brand": {"type": "string"},
427+
"is_secondhand": {"type": "boolean"}
428+
}
429+
}
430+
```
431+
""",
435432
},
436433
{
437434
"role": "user",
438435
"content": [
439436
{
440437
"type": "text",
441-
"text": (
442-
f"The title of the product is: {sample['product_title']}\n"
443-
f"The description of the product is: "
444-
f"{sample['product_description']}\n\n",
445-
"These are the possible product categories: ",
446-
f"{sample['potential_product_categories']}.",
447-
"You must choose only one and return the answer"
448-
" as string and not as a list",
449-
),
438+
"text": f"""The title of the product is the following:
439+
```text
440+
{sample['product_title']}
441+
```
442+
The description of the product is the following:
443+
```text
444+
{sample['product_description']}
445+
```
446+
The following are the possible product categories:
447+
```json
448+
{sample['potential_product_categories']}
449+
```
450+
""",
450451
},
451452
{
452453
"type": "image_url",

0 commit comments

Comments
 (0)