Commit 6a5f17d

committed
revert evaluation.py changes after analysing the discrepancy in is_secondhand f1 score
1 parent 1450143 commit 6a5f17d

1 file changed: +134 -76 lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/evaluation.py

Lines changed: 134 additions & 76 deletions
@@ -3,12 +3,13 @@
 from __future__ import annotations
 
 import json
+import os
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 import numpy as np
 from datasets import load_dataset
-from hiclass.metrics import f1  # type: ignore[import-untyped]
 from loguru import logger
 from pydantic import ValidationError
 from rapidfuzz import fuzz  # type: ignore[import-untyped]
@@ -22,11 +23,12 @@
 
 from .schema import ProductMetadata
 
-_TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
 _PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
 _PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
 _CATEGORY_SEPARATOR = " > "
 
+_WORKER_CONTEXT = {}
+_MAX_JOBS = 4
 
 def get_hierarchical_components(
     predicted_path: str,
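
A note on the `_WORKER_CONTEXT` / `init_worker` pair introduced above: a module-level dict populated by a `ProcessPoolExecutor` initializer is the usual way to hand a large read-only object (here the Hugging Face dataset) to worker processes once, instead of pickling it into every task. A minimal sketch of the pattern, with illustrative names (`shared`, `lookup`) that are not from the diff:

    from concurrent.futures import ProcessPoolExecutor

    _WORKER_CONTEXT = {}  # each worker process gets its own copy

    def init_worker(data: dict) -> None:
        # Runs once per worker process, right after the process starts.
        _WORKER_CONTEXT["data"] = data

    def lookup(key: str) -> int:
        # Tasks read the process-local copy; nothing is re-sent per task.
        return len(_WORKER_CONTEXT["data"][key])

    if __name__ == "__main__":
        shared = {"a": [1, 2, 3], "b": [4]}
        with ProcessPoolExecutor(max_workers=2,
                                 initializer=init_worker,
                                 initargs=(shared,)) as ex:
            print(list(ex.map(lookup, ["a", "b"])))  # [3, 1]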
@@ -110,7 +112,6 @@ def calculate_hierarchical_f1(
 
     return 0.0 if hp + hr == 0 else 2 * (hp * hr) / (hp + hr)
 
-
 def calculate_brand_f1_score(data: list[tuple[str, str]]) -> float:
     """Calculate the F1 score of brand field.
 
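
For context on the `2 * (hp * hr) / (hp + hr)` line kept above: hierarchical F1 is the harmonic mean of hierarchical precision hP = |P ∩ T| / |P| and hierarchical recall hR = |P ∩ T| / |T|, where P and T are the node sets of the predicted and true category paths. A hedged sketch of that computation (the repository's actual `get_hierarchical_components` may aggregate differently):

    def hierarchical_f1(pairs: list[tuple[str, str]], separator: str = " > ") -> float:
        overlap = pred_total = true_total = 0
        for pred, true in pairs:
            p = set(pred.split(separator))  # nodes on the predicted path
            t = set(true.split(separator))  # nodes on the true path
            overlap += len(p & t)
            pred_total += len(p)
            true_total += len(t)
        hp = overlap / pred_total  # hierarchical precision
        hr = overlap / true_total  # hierarchical recall
        return 0.0 if hp + hr == 0 else 2 * (hp * hr) / (hp + hr)

    print(hierarchical_f1([("A > B > C", "A > B > D")]))  # hp = hr = 2/3 -> 0.666...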
@@ -141,7 +142,6 @@ def calculate_brand_f1_score(data: list[tuple[str, str]]) -> float:
     # For 1-to-1 extraction, Accuracy = Recall = Micro F1
     return sum(matches) / len(matches)
 
-
 def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
     """Calculate F1 score of is_secondhand field.
 
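
Why the kept comment "For 1-to-1 extraction, Accuracy = Recall = Micro F1" holds: with exactly one predicted brand per item, every miss counts as both a false positive (for the predicted label) and a false negative (for the true label), so TP + FP = TP + FN = N. Micro precision and micro recall are then both TP / N, and the harmonic mean of two equal values is that value, i.e. plain accuracy.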
@@ -159,77 +159,54 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
     return f1_score(y_src, y_pred)
 
 
-def calculate_hiclass_f1(
-    data: list[tuple[str, str]],
-    separator: str = _CATEGORY_SEPARATOR,
-) -> float:
-    """Alt method to calculate hierarchical F1.
+def _process_chunk_rnd_brand(args: tuple[str, dict, dict]) -> tuple[str, str]:
+    """Function to process only chunks for random brand predictions.
 
     Args:
-        data: List of tuples of predicted and true values
-        separator: The separator used to split the paths into levels of the category.
-
-    Returs:
-        f1 score
+        args: Tuple containing
     """
-    y_pred_raw = []
-    y_true_raw = []
+    pred_brand, elem, data_source = args
+    # We pass the specific data row needed, or the whole structure if efficient
+    return (pred_brand, data_source[elem["qsl_idx"]]["ground_truth_brand"])
 
-    for pred, src in data:
-        path1 = pred.split(separator)
-        path2 = src.split(separator)
-
-        y_pred_raw.append(path1)
-        y_true_raw.append(path2)
-
-    # 2. Find the global maximum length across ALL samples
-    # We check the longest path in both true and pred lists
-    max_len = max(len(p) for p in y_true_raw + y_pred_raw)
-
-    # 3. Pad all lists to the global max_len
-    for i in range(len(y_true_raw)):
-        # Pad Truth
-        pad_len_true = max_len - len(y_true_raw[i])
-        y_true_raw[i] += [_TRUE_CATEGORY_PAD] * pad_len_true
+def init_worker(dataset: dict) -> None:
+    """Initialize worker data to process each chunk.
 
-        # Pad Prediction
-        pad_len_pred = max_len - len(y_pred_raw[i])
-        y_pred_raw[i] += [_PRED_CATEGORY_PAD] * pad_len_pred
+    Args:
+        dataset: huggingface dataset
+    """
+    _WORKER_CONTEXT["dataset"] = dataset
 
-    # 4. Convert to numpy arrays
-    y_true = np.array(y_true_raw)
-    y_pred = np.array(y_pred_raw)
+def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
+    """Retrieve relevant information from each chunk of data.
 
-    # 5. Calculate Score
-    return f1(y_true, y_pred)
+    Args:
+        args: Tuple that contains chunk of data and seed
 
+    Returns:
+        Object with processed information
+    """
+    chunk_data, seed = args
 
-def run_evaluation(random_seed: int, filename: FilePath,
-                   dataset: DatasetCLI) -> None:
-    """Main function to run the evaluation."""
-    rng = np.random.default_rng(seed=random_seed)
-    with Path.open(filename) as f:
-        model_output = json.load(f)
+    # 1. Access the global dataset
+    dataset = _WORKER_CONTEXT["dataset"]
 
-    original_data = load_dataset(
-        dataset.repo_id,
-        token=dataset.token,
-        split="+".join(dataset.split),
-    )
+    # 2. Create a local, reproducible RNG for this specific chunk
+    local_rng = np.random.default_rng(seed)
 
     num_unparsable_responses = 0
     category_dataset_pred_src = []
     category_rand_pred_src = []
     is_secondhand_pred_src = []
     is_secondhand_rand_pred_src = []
     brand_pred_src = []
-
     all_possible_brands = set()
+    error_messages = []
 
-    for elem in model_output:
+    for elem in chunk_data:
         idx = elem["qsl_idx"]
         response = bytes.fromhex(elem["data"]).decode("utf-8")
-        ground_truth_item = original_data[idx]
+        ground_truth_item = dataset[idx]
         all_possible_brands.add(ground_truth_item["ground_truth_brand"])
         try:
             pred_item = ProductMetadata.model_validate_json(response)
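
The per-chunk seeding in `_process_chunk` is what keeps the parallel run reproducible: each chunk's random draws depend only on its own seed, not on which worker picks it up or in what order. A small self-contained illustration (chunk count and seed are arbitrary here):

    import numpy as np

    master = np.random.default_rng(1234)
    seeds = master.integers(0, 2**32, size=3)  # one seed per chunk, in a fixed order

    # Re-creating a generator from the same seed reproduces the same stream:
    draws = [np.random.default_rng(s).random() for s in seeds]
    draws_again = [np.random.default_rng(s).random() for s in seeds]
    assert draws == draws_again

    # NumPy also offers SeedSequence.spawn for the same goal:
    children = np.random.SeedSequence(1234).spawn(3)
    streams = [np.random.default_rng(c) for c in children]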
@@ -245,14 +222,14 @@ def run_evaluation(random_seed: int, filename: FilePath,
                 ),
             ),
             brand=_PRED_BRAND_PAD,
-            is_secondhand=rng.choice([True, False], size=1).tolist()[0],
+            is_secondhand=local_rng.choice([True, False], size=1).tolist()[0],
         )
-        logger.error(
-            "Response\n{}\n(for the sample at index {}) cannot be validated against"
-            " the expected schema. Overwriting this response into \n{}\n",
-            response,
-            idx,
-            pred_item,
+        error_messages.append(
+            (
+                f"Response\n{response}\n(for the sample at index {idx})"
+                f"cannot be validated against"
+                f" the expected schema. Overwriting this response into \n{pred_item}\n",
+            ),
         )
     category_dataset_pred_src.append(
         (pred_item.category, ground_truth_item["ground_truth_category"]),
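
The try/except shown above is pydantic v2's parse-or-fallback idiom: `model_validate_json` raises `ValidationError` both for malformed JSON and for schema violations, and the code substitutes a padded placeholder so every sample still contributes one row to each metric. A minimal sketch with a stand-in model (the real `ProductMetadata` fields are not visible in this diff):

    from pydantic import BaseModel, ValidationError

    class Product(BaseModel):  # stand-in; ProductMetadata differs
        category: str
        brand: str
        is_secondhand: bool

    def parse_or_pad(raw: str) -> tuple[Product, bool]:
        try:
            return Product.model_validate_json(raw), True
        except ValidationError:
            # Sentinel fallback so downstream metrics keep one entry per item.
            return Product(category="<pad>", brand="<pad>", is_secondhand=False), False

    print(parse_or_pad('{"category": "A > B", "brand": "Acme", "is_secondhand": false}')[1])  # True
    print(parse_or_pad("not json")[1])  # False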
@@ -268,35 +245,119 @@ def run_evaluation(random_seed: int, filename: FilePath,
         )
         # random category selection
         # Uniform distribution is the default
-        rand_cat = rng.choice(
+        rand_cat = local_rng.choice(
             ground_truth_item["potential_product_categories"])
         category_rand_pred_src.append(
             (rand_cat, ground_truth_item["ground_truth_category"]),
         )
         # random is_secondhand selection
-        rand_is_secondhand = rng.choice([True, False])
+        rand_is_secondhand = local_rng.choice([True, False])
         is_secondhand_rand_pred_src.append(
             (rand_is_secondhand,
              ground_truth_item["ground_truth_is_secondhand"]),
         )
 
+    return {
+        "num_unparsable_responses": num_unparsable_responses,
+        "error_messages": error_messages,
+        "category_dataset_pred_src": category_dataset_pred_src,
+        "category_rand_pred_src": category_rand_pred_src,
+        "is_secondhand_pred_src": is_secondhand_pred_src,
+        "is_secondhand_rand_pred_src": is_secondhand_rand_pred_src,
+        "brand_pred_src": brand_pred_src,
+        "all_possible_brands": list(all_possible_brands),
+    }
+
+
+def run_evaluation(random_seed: int, filename: FilePath,
+                   dataset: DatasetCLI) -> None:
+    """Main function to run the evaluation."""
+    master_rng = np.random.default_rng(seed=random_seed)
+    with Path.open(filename) as f:
+        model_output = json.load(f)
+
+    original_data = load_dataset(
+        dataset.repo_id,
+        token=dataset.token,
+        split="+".join(dataset.split),
+    )
+
+    # Get the number of available CPUs and derive the chunk size
+    cpu_count = min(os.cpu_count() or 1, _MAX_JOBS)
+    chunk_size = max(len(model_output) // cpu_count, 1)
+    # Create chunks
+    output_chunks = [
+        model_output[i : i + chunk_size]
+        for i in range(0, len(model_output), chunk_size)
+    ]
+
+    # Generate seeds: one per chunk to ensure reproducibility.
+    # The master_rng generates these, so the whole run is
+    # deterministic based on `random_seed`.
+    chunk_seeds = master_rng.integers(0, 2**32, size=len(output_chunks))
+
+    # Zip them: each task is ([model_out_1, ...], 12345)
+    tasks = zip(output_chunks, chunk_seeds, strict=False)
+
+    num_unparsable_responses = 0
+    err_messages = []
+    category_dataset_pred_src = []
+    category_rand_pred_src = []
+    is_secondhand_pred_src = []
+    is_secondhand_rand_pred_src = []
+    brand_pred_src = []
+    all_possible_brands = []
+
+    with ProcessPoolExecutor(
+        max_workers=cpu_count,
+        initializer=init_worker,
+        initargs=(original_data,),
+    ) as executor:
+        # Execute
+        chunk_results = list(executor.map(_process_chunk, tasks))
+
+    for chunk in chunk_results:
+        num_unparsable_responses += chunk["num_unparsable_responses"]
+        err_messages.extend(chunk["error_messages"])
+        category_dataset_pred_src.extend(chunk["category_dataset_pred_src"])
+        category_rand_pred_src.extend(chunk["category_rand_pred_src"])
+        is_secondhand_pred_src.extend(chunk["is_secondhand_pred_src"])
+        is_secondhand_rand_pred_src.extend(chunk["is_secondhand_rand_pred_src"])
+        brand_pred_src.extend(chunk["brand_pred_src"])
+        all_possible_brands.extend(chunk["all_possible_brands"])
+
+    for err in err_messages:
+        logger.error("{}", err)
+
     category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
-    hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
     is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
     brand_score = calculate_brand_f1_score(brand_pred_src)
 
     rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
-    rand_hiclass_f1_score = calculate_hiclass_f1(category_rand_pred_src)
+
     rand_is_seconhand_f1_score = calculate_secondhand_f1(
         is_secondhand_rand_pred_src)
+
+    all_brands_list = list(set(all_possible_brands))
+    random_brand_predictions = master_rng.choice(
+        all_brands_list,
+        size=len(model_output))
+
+    args_list = (
+        (pred, elem, original_data)
+        for pred, elem in zip(random_brand_predictions, model_output, strict=False)
+    )
+
+    with ProcessPoolExecutor() as executor:
+        rand_brand_data = list(executor.map(_process_chunk_rnd_brand,
+                                            args_list,
+                                            chunksize=chunk_size))
+
     rand_brand_score = calculate_brand_f1_score(
-        [
-            (
-                rng.choice(list(all_possible_brands)),
-                original_data[elem["qsl_idx"]]["ground_truth_brand"],
-            )
-            for elem in model_output
-        ],
+        rand_brand_data,
     )
 
     logger.info(
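
Structurally, the new `run_evaluation` is a scatter/gather: slice `model_output` into roughly `cpu_count` chunks, map `_process_chunk` over a pool, then merge the per-chunk dicts back into flat lists. A compact sketch of the same shape, with illustrative names:

    import os
    from concurrent.futures import ProcessPoolExecutor

    def process(chunk: list[int]) -> dict:
        return {"total": sum(chunk), "count": len(chunk)}

    if __name__ == "__main__":
        items = list(range(100))
        workers = min(os.cpu_count() or 1, 4)
        size = max(len(items) // workers, 1)
        chunks = [items[i:i + size] for i in range(0, len(items), size)]

        with ProcessPoolExecutor(max_workers=workers) as ex:
            results = list(ex.map(process, chunks))

        merged = {"total": 0, "count": 0}
        for r in results:  # gather: fold per-chunk partials together
            merged["total"] += r["total"]
            merged["count"] += r["count"]
        print(merged)  # {'total': 4950, 'count': 100}

Note that the second pool passes `chunksize=chunk_size` to `executor.map`; for many tiny tasks like `_process_chunk_rnd_brand`, batching several tasks per inter-process round trip is usually what makes the pool pay off.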
@@ -307,25 +368,22 @@ def run_evaluation(random_seed: int, filename: FilePath,
             [
                 "From accuracy file",
                 category_f1_score,
-                hiclass_f1_score,
                 brand_score,
                 is_secondhand_f1_score,
             ],
             [
                 "Random selection",
                 rand_cat_f1_score,
-                rand_hiclass_f1_score,
                 rand_brand_score,
                 rand_is_seconhand_f1_score,
             ],
         ],
         headers=[
             "Results",
             "Category hierarchical F1 Score",
-            "Category HiClass F1 Score",
             "Brand F1 Score",
             "Is_secondhand F1 Score",
         ],
         tablefmt="fancy_grid",
         ),
-        )
+    )
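
The final report's `headers=` / `tablefmt="fancy_grid"` arguments match the call shape of the `tabulate` package (its import falls outside the hunks shown, so that is an inference). A minimal example with made-up scores:

    from tabulate import tabulate

    rows = [
        ["From accuracy file", 0.83, 0.91, 0.88],  # illustrative numbers
        ["Random selection", 0.21, 0.07, 0.50],
    ]
    print(tabulate(
        rows,
        headers=["Results", "Category hierarchical F1 Score",
                 "Brand F1 Score", "Is_secondhand F1 Score"],
        tablefmt="fancy_grid",
    ))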
