from __future__ import annotations

import json
+ import os
+ from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from datasets import load_dataset
- from hiclass.metrics import f1  # type: ignore[import-untyped]
from loguru import logger
from pydantic import ValidationError
from rapidfuzz import fuzz  # type: ignore[import-untyped]

from .schema import ProductMetadata

- _TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
_PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
_PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
_CATEGORY_SEPARATOR = " > "

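+ # Per-process context shared with worker processes: init_worker() stores the
+ # dataset here so each worker can look up ground truth without receiving it
+ # again for every task.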
+ _WORKER_CONTEXT = {}
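+ # Upper bound on the number of parallel worker processes.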
+ _MAX_JOBS = 4

def get_hierarchical_components(
    predicted_path: str,
@@ -110,7 +112,6 @@ def calculate_hierarchical_f1(

    return 0.0 if hp + hr == 0 else 2 * (hp * hr) / (hp + hr)

-
def calculate_brand_f1_score(data: list[tuple[str, str]]) -> float:
    """Calculate the F1 score of brand field.

@@ -141,7 +142,6 @@ def calculate_brand_f1_score(data: list[tuple[str, str]]) -> float:
    # For 1-to-1 extraction, Accuracy = Recall = Micro F1
    return sum(matches) / len(matches)

-
def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
    """Calculate F1 score of is_secondhand field.

@@ -159,77 +159,54 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
    return f1_score(y_src, y_pred)


- def calculate_hiclass_f1(
-     data: list[tuple[str, str]],
-     separator: str = _CATEGORY_SEPARATOR,
- ) -> float:
-     """Alt method to calculate hierarchical F1.
+ def _process_chunk_rnd_brand(args: tuple[str, dict, dict]) -> tuple[str, str]:
+     """Pair one random brand prediction with its ground-truth brand.

    Args:
-         data: List of tuples of predicted and true values
-         separator: The separator used to split the paths into levels of the category.
-
-     Returs:
-         f1 score
+         args: Tuple of (predicted brand, model output element, source dataset).
    """
-     y_pred_raw = []
-     y_true_raw = []
+     pred_brand, elem, data_source = args
+     # The full data source is passed in; look up the ground-truth brand by qsl_idx.
+     return (pred_brand, data_source[elem["qsl_idx"]]["ground_truth_brand"])

-     for pred, src in data:
-         path1 = pred.split(separator)
-         path2 = src.split(separator)
-
-         y_pred_raw.append(path1)
-         y_true_raw.append(path2)
-
-     # 2. Find the global maximum length across ALL samples
-     # We check the longest path in both true and pred lists
-     max_len = max(len(p) for p in y_true_raw + y_pred_raw)
-
-     # 3. Pad all lists to the global max_len
-     for i in range(len(y_true_raw)):
-         # Pad Truth
-         pad_len_true = max_len - len(y_true_raw[i])
-         y_true_raw[i] += [_TRUE_CATEGORY_PAD] * pad_len_true
+ def init_worker(dataset: dict) -> None:
+     """Initialize the per-worker state used when processing chunks.

-         # Pad Prediction
-         pad_len_pred = max_len - len(y_pred_raw[i])
-         y_pred_raw[i] += [_PRED_CATEGORY_PAD] * pad_len_pred
+     Args:
+         dataset: Hugging Face dataset shared with every worker process.
+     """
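+     # ProcessPoolExecutor runs this initializer once per worker process, so the
+     # dataset is sent to each worker a single time rather than with every task.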
+     _WORKER_CONTEXT["dataset"] = dataset

-     # 4. Convert to numpy arrays
-     y_true = np.array(y_true_raw)
-     y_pred = np.array(y_pred_raw)
+ def _process_chunk(args: tuple[list[dict], int]) -> dict[str, object]:
+     """Process one chunk of model outputs and gather per-chunk evaluation data.

-     # 5. Calculate Score
-     return f1(y_true, y_pred)
+     Args:
+         args: Tuple containing a chunk of model outputs and the RNG seed for that chunk.

+     Returns:
+         Dictionary with the chunk's prediction/ground-truth pairs, error messages,
+         and the number of unparsable responses.
+     """
+     chunk_data, seed = args

- def run_evaluation(random_seed: int, filename: FilePath,
-                    dataset: DatasetCLI) -> None:
-     """Main function to run the evaluation."""
-     rng = np.random.default_rng(seed=random_seed)
-     with Path.open(filename) as f:
-         model_output = json.load(f)
+     # 1. Access the global dataset
+     dataset = _WORKER_CONTEXT["dataset"]

-     original_data = load_dataset(
-         dataset.repo_id,
-         token=dataset.token,
-         split="+".join(dataset.split),
-     )
+     # 2. Create a local, reproducible RNG for this specific chunk
+     local_rng = np.random.default_rng(seed)

    num_unparsable_responses = 0
    category_dataset_pred_src = []
    category_rand_pred_src = []
    is_secondhand_pred_src = []
    is_secondhand_rand_pred_src = []
    brand_pred_src = []
-
    all_possible_brands = set()
+     error_messages = []

-     for elem in model_output:
+     for elem in chunk_data:
        idx = elem["qsl_idx"]
        response = bytes.fromhex(elem["data"]).decode("utf-8")
-         ground_truth_item = original_data[idx]
+         ground_truth_item = dataset[idx]
        all_possible_brands.add(ground_truth_item["ground_truth_brand"])
        try:
            pred_item = ProductMetadata.model_validate_json(response)
@@ -245,14 +222,14 @@ def run_evaluation(random_seed: int, filename: FilePath,
                    ),
                ),
                brand=_PRED_BRAND_PAD,
-                 is_secondhand=rng.choice([True, False], size=1).tolist()[0],
+                 is_secondhand=local_rng.choice([True, False], size=1).tolist()[0],
            )
-             logger.error(
-                 "Response\n{}\n(for the sample at index {}) cannot be validated against"
-                 " the expected schema. Overwriting this response into\n{}\n",
-                 response,
-                 idx,
-                 pred_item,
+             error_messages.append(
+                 f"Response\n{response}\n(for the sample at index {idx}) cannot be"
+                 f" validated against the expected schema."
+                 f" Overwriting this response into\n{pred_item}\n",
            )
        category_dataset_pred_src.append(
            (pred_item.category, ground_truth_item["ground_truth_category"]),
@@ -268,35 +245,119 @@ def run_evaluation(random_seed: int, filename: FilePath,
        )
        # random category selection
        # Uniform distribution is the default
-         rand_cat = rng.choice(
+         rand_cat = local_rng.choice(
            ground_truth_item["potential_product_categories"])
        category_rand_pred_src.append(
            (rand_cat, ground_truth_item["ground_truth_category"]),
        )
        # random is_secondhand selection
-         rand_is_secondhand = rng.choice([True, False])
+         rand_is_secondhand = local_rng.choice([True, False])
        is_secondhand_rand_pred_src.append(
            (rand_is_secondhand,
             ground_truth_item["ground_truth_is_secondhand"]),
        )

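+     # Return plain Python containers; the parent process merges these lists
+     # across chunks and de-duplicates the brands afterwards.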
+     return {
+         "num_unparsable_responses": num_unparsable_responses,
+         "error_messages": error_messages,
+         "category_dataset_pred_src": category_dataset_pred_src,
+         "category_rand_pred_src": category_rand_pred_src,
+         "is_secondhand_pred_src": is_secondhand_pred_src,
+         "is_secondhand_rand_pred_src": is_secondhand_rand_pred_src,
+         "brand_pred_src": brand_pred_src,
+         "all_possible_brands": list(all_possible_brands),
+     }
+
+
+ def run_evaluation(random_seed: int, filename: FilePath,
+                    dataset: DatasetCLI) -> None:
+     """Main function to run the evaluation."""
+     master_rng = np.random.default_rng(seed=random_seed)
+     with Path.open(filename) as f:
+         model_output = json.load(f)
+
+     original_data = load_dataset(
+         dataset.repo_id,
+         token=dataset.token,
+         split="+".join(dataset.split),
+     )
+
+     # Get the number of available CPUs (capped at _MAX_JOBS) and derive the chunk size
+     cpu_count = min(os.cpu_count() or 1, _MAX_JOBS)
+     chunk_size = max(len(model_output) // cpu_count, 1)
+     # Create chunks
+     output_chunks = [
+         model_output[i : i + chunk_size]
+         for i in range(0, len(model_output), chunk_size)
+     ]
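+     # The last chunk may be shorter when len(model_output) is not an exact
+     # multiple of chunk_size.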
+
+     # Generate Seeds
+     # One seed per chunk to ensure reproducibility.
+     # The master_rng generates these,
+     # so the whole run is deterministic based on `random_seed`.
+     chunk_seeds = master_rng.integers(0, 2**32, size=len(output_chunks))
+
+     # Zip them: Each task is ([model_out_1, ...], 12345)
+     tasks = zip(output_chunks, chunk_seeds, strict=False)
+
+     num_unparsable_responses = 0
+     err_messages = []
+     category_dataset_pred_src = []
+     category_rand_pred_src = []
+     is_secondhand_pred_src = []
+     is_secondhand_rand_pred_src = []
+     brand_pred_src = []
+     all_possible_brands = []
+
+     with ProcessPoolExecutor(
+         max_workers=cpu_count,
+         initializer=init_worker,
+         initargs=(original_data,),
+     ) as executor:
+         # Execute: map each (chunk, seed) task onto the worker pool
+         chunk_results = list(executor.map(_process_chunk, tasks))
+
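+     # Merge the per-chunk results back into flat lists; executor.map preserves
+     # the order of the submitted chunks.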
+     for chunk in chunk_results:
+         num_unparsable_responses += chunk["num_unparsable_responses"]
+         err_messages.extend(chunk["error_messages"])
+         category_dataset_pred_src.extend(chunk["category_dataset_pred_src"])
+         category_rand_pred_src.extend(chunk["category_rand_pred_src"])
+         is_secondhand_pred_src.extend(chunk["is_secondhand_pred_src"])
+         is_secondhand_rand_pred_src.extend(chunk["is_secondhand_rand_pred_src"])
+         brand_pred_src.extend(chunk["brand_pred_src"])
+         all_possible_brands.extend(chunk["all_possible_brands"])
+
+     for err in err_messages:
+         logger.error("{}", err)
+
    category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
-     hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
    is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
    brand_score = calculate_brand_f1_score(brand_pred_src)

    rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
-     rand_hiclass_f1_score = calculate_hiclass_f1(category_rand_pred_src)
+
    rand_is_seconhand_f1_score = calculate_secondhand_f1(
        is_secondhand_rand_pred_src)
+
+
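+     # Random-brand baseline: draw one candidate brand per model output with the
+     # master RNG so the baseline stays reproducible for a given random_seed.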
+     all_brands_list = list(set(all_possible_brands))
+     random_brand_predictions = master_rng.choice(
+         all_brands_list,
+         size=len(model_output))
+
+     args_list = (
+         (pred, elem, original_data)
+         for pred, elem in zip(random_brand_predictions, model_output, strict=False)
+     )
+
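+     # A plain executor (no initializer) is enough here: each task already carries
+     # the data it needs, and chunksize batches the per-element tasks to limit
+     # inter-process overhead.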
+     with ProcessPoolExecutor() as executor:
+         rand_brand_data = list(executor.map(_process_chunk_rnd_brand,
+                                             args_list,
+                                             chunksize=chunk_size))
+
    rand_brand_score = calculate_brand_f1_score(
-         [
-             (
-                 rng.choice(list(all_possible_brands)),
-                 original_data[elem["qsl_idx"]]["ground_truth_brand"],
-             )
-             for elem in model_output
-         ],
+         rand_brand_data,
    )

    logger.info(
@@ -307,25 +368,22 @@ def run_evaluation(random_seed: int, filename: FilePath,
                [
                    "From accuracy file",
                    category_f1_score,
-                     hiclass_f1_score,
                    brand_score,
                    is_secondhand_f1_score,
                ],
                [
                    "Random selection",
                    rand_cat_f1_score,
-                     rand_hiclass_f1_score,
                    rand_brand_score,
                    rand_is_seconhand_f1_score,
                ],
            ],
            headers=[
                "Results",
                "Category hierarchical F1 Score",
-                 "Category HiClass F1 Score",
                "Brand F1 Score",
                "Is_secondhand F1 Score",
            ],
            tablefmt="fancy_grid",
        ),
-     )
+     )