Skip to content

Commit d10d634

Browse files
committed
revert evaluation changes
1 parent 7576e0c commit d10d634

File tree

1 file changed

+72
-135
lines changed
  • multimodal/vl2l/src/mlperf_inference_multimodal_vl2l

1 file changed

+72
-135
lines changed

multimodal/vl2l/src/mlperf_inference_multimodal_vl2l/evaluation.py

Lines changed: 72 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
from __future__ import annotations
44

55
import json
6-
import os
7-
from concurrent.futures import ProcessPoolExecutor
86
from pathlib import Path
97
from typing import TYPE_CHECKING
108

119
import numpy as np
1210
from datasets import load_dataset
11+
from hiclass.metrics import f1 # type: ignore[import-untyped]
1312
from loguru import logger
1413
from pydantic import ValidationError
1514
from rapidfuzz import fuzz # type: ignore[import-untyped]
@@ -23,13 +22,11 @@
2322

2423
from .schema import ProductMetadata
2524

25+
_TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
2626
_PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
2727
_PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
2828
_CATEGORY_SEPARATOR = " > "
2929

30-
_WORKER_CONTEXT = {}
31-
_MAX_JOBS = 4
32-
3330

3431
def get_hierarchical_components(
3532
predicted_path: str,
@@ -162,56 +159,77 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
162159
return f1_score(y_src, y_pred)
163160

164161

165-
def _process_chunk_rnd_brand(args: tuple[str, dict, dict]) -> tuple[str, str]:
166-
"""Function to process only chunks for random brand predictions.
162+
def calculate_hiclass_f1(
163+
data: list[tuple[str, str]],
164+
separator: str = _CATEGORY_SEPARATOR,
165+
) -> float:
166+
"""Alternative method to calculate hierarchical F1.
167167
168168
Args:
169-
args: Tuple containing
169+
data: List of tuples of predicted and true values
170+
separator: The separator used to split the paths into levels of the category.
171+
172+
Returns:
173+
f1 score
170174
"""
171-
pred_brand, elem, data_source = args
172-
# We pass the specific data row needed, or the whole structure if efficient
173-
return (pred_brand, data_source[elem["qsl_idx"]]["ground_truth_brand"])
175+
y_pred_raw = []
176+
y_true_raw = []
174177

178+
for pred, src in data:
179+
path1 = pred.split(separator)
180+
path2 = src.split(separator)
175181

176-
def init_worker(dataset: dict) -> None:
177-
"""Initialize worker data to process each chunk.
182+
y_pred_raw.append(path1)
183+
y_true_raw.append(path2)
178184

179-
Args:
180-
dataset: huggingface dataset
181-
"""
182-
_WORKER_CONTEXT["dataset"] = dataset
185+
# 2. Find the global maximum length across ALL samples
186+
# We check the longest path in both true and pred lists
187+
max_len = max(len(p) for p in y_true_raw + y_pred_raw)
183188

189+
# 3. Pad all lists to the global max_len
190+
for i in range(len(y_true_raw)):
191+
# Pad Truth
192+
pad_len_true = max_len - len(y_true_raw[i])
193+
y_true_raw[i] += [_TRUE_CATEGORY_PAD] * pad_len_true
184194

185-
def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
186-
"""Retrieve relevant information from each chunk of data.
195+
# Pad Prediction
196+
pad_len_pred = max_len - len(y_pred_raw[i])
197+
y_pred_raw[i] += [_PRED_CATEGORY_PAD] * pad_len_pred
187198

188-
Args:
189-
args: Tuple that contains chunk of data and seed
199+
# 4. Convert to numpy arrays
200+
y_true = np.array(y_true_raw)
201+
y_pred = np.array(y_pred_raw)
202+
203+
# 5. Calculate Score
204+
return f1(y_true, y_pred)
190205

191-
Returns:
192-
Object with processed information
193-
"""
194-
chunk_data, seed = args
195206

196-
# 1. Access the global dataset
197-
dataset = _WORKER_CONTEXT["dataset"]
207+
def run_evaluation(random_seed: int, filename: FilePath,
208+
dataset: DatasetCLI) -> None:
209+
"""Main function to run the evaluation."""
210+
rng = np.random.default_rng(seed=random_seed)
211+
with Path.open(filename) as f:
212+
model_output = json.load(f)
198213

199-
# 2. Create a local, reproducible RNG for this specific chunk
200-
local_rng = np.random.default_rng(seed)
214+
original_data = load_dataset(
215+
dataset.repo_id,
216+
token=dataset.token,
217+
split="+".join(dataset.split),
218+
)
201219

202220
num_unparsable_responses = 0
203221
category_dataset_pred_src = []
204222
category_rand_pred_src = []
205223
is_secondhand_pred_src = []
206224
is_secondhand_rand_pred_src = []
207225
brand_pred_src = []
226+
208227
all_possible_brands = set()
209-
error_messages = []
210228

211-
for elem in chunk_data:
229+
for elem in model_output:
212230
idx = elem["qsl_idx"]
213231
response = bytes.fromhex(elem["data"]).decode("utf-8")
214-
ground_truth_item = dataset[idx]
232+
ground_truth_item = original_data[idx]
215233
all_possible_brands.add(ground_truth_item["ground_truth_brand"])
216234
try:
217235
pred_item = ProductMetadata.model_validate_json(response)
@@ -227,15 +245,14 @@ def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
227245
),
228246
),
229247
brand=_PRED_BRAND_PAD,
230-
is_secondhand=local_rng.choice(
231-
[True, False], size=1).tolist()[0],
248+
is_secondhand=rng.choice([True, False], size=1).tolist()[0],
232249
)
233-
error_messages.append(
234-
(
235-
f"Response\n{response}\n(for the sample at index {idx})"
236-
f"cannot be validated against"
237-
f" the expected schema. Overwriting this response into \n{pred_item}\n",
238-
),
250+
logger.error(
251+
"Response\n{}\n(for the sample at index {}) cannot be validated against"
252+
" the expected schema. Overwriting this response into \n{}\n",
253+
response,
254+
idx,
255+
pred_item,
239256
)
240257
category_dataset_pred_src.append(
241258
(pred_item.category, ground_truth_item["ground_truth_category"]),
@@ -251,118 +268,35 @@ def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
251268
)
252269
# random category selection
253270
# Uniform distribution is the default
254-
rand_cat = local_rng.choice(
271+
rand_cat = rng.choice(
255272
ground_truth_item["potential_product_categories"])
256273
category_rand_pred_src.append(
257274
(rand_cat, ground_truth_item["ground_truth_category"]),
258275
)
259276
# random is_secondhand selection
260-
rand_is_secondhand = local_rng.choice([True, False])
277+
rand_is_secondhand = rng.choice([True, False])
261278
is_secondhand_rand_pred_src.append(
262279
(rand_is_secondhand,
263280
ground_truth_item["ground_truth_is_secondhand"]),
264281
)
265282

266-
return {
267-
"num_unparsable_responses": num_unparsable_responses,
268-
"error_messages": error_messages,
269-
"category_dataset_pred_src": category_dataset_pred_src,
270-
"category_rand_pred_src": category_rand_pred_src,
271-
"is_secondhand_pred_src": is_secondhand_pred_src,
272-
"is_secondhand_rand_pred_src": is_secondhand_rand_pred_src,
273-
"brand_pred_src": brand_pred_src,
274-
"all_possible_brands": list(all_possible_brands),
275-
}
276-
277-
278-
def run_evaluation(random_seed: int, filename: FilePath,
279-
dataset: DatasetCLI) -> None:
280-
"""Main function to run the evaluation."""
281-
master_rng = np.random.default_rng(seed=random_seed)
282-
with Path.open(filename) as f:
283-
model_output = json.load(f)
284-
285-
original_data = load_dataset(
286-
dataset.repo_id,
287-
token=dataset.token,
288-
split="+".join(dataset.split),
289-
)
290-
291-
# get number of available CPU and get chunk size
292-
cpu_count = min(os.cpu_count() or 1, _MAX_JOBS)
293-
chunk_size = max(len(model_output) // cpu_count, 1)
294-
# Create chunks
295-
output_chunks = [
296-
model_output[i: i + chunk_size]
297-
for i in range(0, len(model_output), chunk_size)
298-
]
299-
300-
# Generate Seeds
301-
# One seed per chunk to ensure reproducibility.
302-
# The master_rng generates these,
303-
# so the whole run is deterministic based on `random_seed`.
304-
chunk_seeds = master_rng.integers(0, 2**32, size=len(output_chunks))
305-
306-
# Zip them: Each task is ([model_out_1, ...], 12345)
307-
tasks = zip(output_chunks, chunk_seeds, strict=False)
308-
309-
num_unparsable_responses = 0
310-
err_messages = []
311-
category_dataset_pred_src = []
312-
category_rand_pred_src = []
313-
is_secondhand_pred_src = []
314-
is_secondhand_rand_pred_src = []
315-
brand_pred_src = []
316-
all_possible_brands = []
317-
318-
with ProcessPoolExecutor(
319-
max_workers=cpu_count,
320-
initializer=init_worker,
321-
initargs=(original_data,),
322-
) as executor:
323-
# Execute
324-
chunk_results = list(executor.map(_process_chunk, tasks))
325-
326-
for chunk in chunk_results:
327-
num_unparsable_responses += chunk["num_unparsable_responses"]
328-
err_messages.extend(chunk["error_messages"])
329-
category_dataset_pred_src.extend(chunk["category_dataset_pred_src"])
330-
category_rand_pred_src.extend(chunk["category_rand_pred_src"])
331-
is_secondhand_pred_src.extend(chunk["is_secondhand_pred_src"])
332-
is_secondhand_rand_pred_src.extend(
333-
chunk["is_secondhand_rand_pred_src"])
334-
brand_pred_src.extend(chunk["brand_pred_src"])
335-
all_possible_brands.extend(chunk["all_possible_brands"])
336-
337-
for err in err_messages:
338-
logger.error("{}", err)
339-
340283
category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
284+
hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
341285
is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
342286
brand_score = calculate_brand_f1_score(brand_pred_src)
343287

344288
rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
345-
289+
rand_hiclass_f1_score = calculate_hiclass_f1(category_rand_pred_src)
346290
rand_is_seconhand_f1_score = calculate_secondhand_f1(
347291
is_secondhand_rand_pred_src)
348-
349-
all_brands_list = list(set(all_possible_brands))
350-
random_brand_predictions = master_rng.choice(
351-
all_brands_list,
352-
size=len(model_output))
353-
354-
args_list = (
355-
(pred, elem, original_data)
356-
for pred, elem in zip(random_brand_predictions, model_output, strict=False)
357-
)
358-
359-
with ProcessPoolExecutor() as executor:
360-
rand_brand_data = list(executor.map(_process_chunk_rnd_brand,
361-
args_list,
362-
chunksize=chunk_size))
363-
364292
rand_brand_score = calculate_brand_f1_score(
365-
rand_brand_data,
293+
[
294+
(
295+
rng.choice(list(all_possible_brands)),
296+
original_data[elem["qsl_idx"]]["ground_truth_brand"],
297+
)
298+
for elem in model_output
299+
],
366300
)
367301

368302
logger.info(
@@ -373,22 +307,25 @@ def run_evaluation(random_seed: int, filename: FilePath,
373307
[
374308
"From accuracy file",
375309
category_f1_score,
310+
hiclass_f1_score,
376311
brand_score,
377312
is_secondhand_f1_score,
378313
],
379314
[
380315
"Random selection",
381316
rand_cat_f1_score,
317+
rand_hiclass_f1_score,
382318
rand_brand_score,
383319
rand_is_seconhand_f1_score,
384320
],
385321
],
386322
headers=[
387323
"Results",
388324
"Category hierarchical F1 Score",
325+
"Category HiClass F1 Score",
389326
"Brand F1 Score",
390327
"Is_secondhand F1 Score",
391328
],
392329
tablefmt="fancy_grid",
393330
),
394-
)
331+
)

0 commit comments

Comments
 (0)