 from __future__ import annotations

 import json
-import os
-from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from typing import TYPE_CHECKING

 import numpy as np
 from datasets import load_dataset
+from hiclass.metrics import f1  # type: ignore[import-untyped]
 from loguru import logger
 from pydantic import ValidationError
 from rapidfuzz import fuzz  # type: ignore[import-untyped]

 from .schema import ProductMetadata

+_TRUE_CATEGORY_PAD = "<|__TRUE_CATEGORY_PAD__|>"
 _PRED_CATEGORY_PAD = "<|__PRED_CATEGORY_PAD__|>"
 _PRED_BRAND_PAD = "<|__PRED_BRAND_PAD__|>"
 _CATEGORY_SEPARATOR = " > "

-_WORKER_CONTEXT = {}
-_MAX_JOBS = 4
-

 def get_hierarchical_components(
     predicted_path: str,
@@ -162,56 +159,77 @@ def calculate_secondhand_f1(data: list[tuple[bool, bool]]) -> float:
     return f1_score(y_src, y_pred)


-def _process_chunk_rnd_brand(args: tuple[str, dict, dict]) -> tuple[str, str]:
-    """Function to process only chunks for random brand predictions.
+def calculate_hiclass_f1(
+    data: list[tuple[str, str]],
+    separator: str = _CATEGORY_SEPARATOR,
+) -> float:
+    """Alternative method to compute the hierarchical F1 score using hiclass.
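+
+    Predicted and true paths of unequal depth are padded with sentinel tokens
+    so that hiclass receives label rows of equal length.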

     Args:
-        args: Tuple containing
+        data: List of (predicted, ground truth) category path tuples.
+        separator: The separator used to split each path into category levels.
+
+    Returns:
+        The hierarchical F1 score.
     """
-    pred_brand, elem, data_source = args
-    # We pass the specific data row needed, or the whole structure if efficient
-    return (pred_brand, data_source[elem["qsl_idx"]]["ground_truth_brand"])
+    y_pred_raw = []
+    y_true_raw = []

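+    # 1. Split each predicted and true path into its category levels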
+    for pred, src in data:
+        path1 = pred.split(separator)
+        path2 = src.split(separator)

-def init_worker(dataset: dict) -> None:
-    """Initialize worker data to process each chunk.
+        y_pred_raw.append(path1)
+        y_true_raw.append(path2)

-    Args:
-        dataset: huggingface dataset
-    """
-    _WORKER_CONTEXT["dataset"] = dataset
+    # 2. Find the global maximum length across ALL samples
+    # We check the longest path in both true and pred lists
+    max_len = max(len(p) for p in y_true_raw + y_pred_raw)

+    # 3. Pad all lists to the global max_len
+    for i in range(len(y_true_raw)):
+        # Pad Truth
+        pad_len_true = max_len - len(y_true_raw[i])
+        y_true_raw[i] += [_TRUE_CATEGORY_PAD] * pad_len_true

-def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
-    """Retrieve relevant information from each chunk of data.
+        # Pad Prediction
+        pad_len_pred = max_len - len(y_pred_raw[i])
+        y_pred_raw[i] += [_PRED_CATEGORY_PAD] * pad_len_pred

-    Args:
-        args: Tuple that contains chunk of data and seed
+    # 4. Convert to numpy arrays
+    y_true = np.array(y_true_raw)
+    y_pred = np.array(y_pred_raw)
+
+    # 5. Calculate Score
+    return f1(y_true, y_pred)

-    Returns:
-        Object with processed information
-    """
-    chunk_data, seed = args

-    # 1. Access the global dataset
-    dataset = _WORKER_CONTEXT["dataset"]
+def run_evaluation(random_seed: int, filename: FilePath,
+                   dataset: DatasetCLI) -> None:
+    """Main function to run the evaluation."""
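+    # One RNG seeded once makes the whole run deterministic for a given seed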
+    rng = np.random.default_rng(seed=random_seed)
+    with Path.open(filename) as f:
+        model_output = json.load(f)

-    # 2. Create a local, reproducible RNG for this specific chunk
-    local_rng = np.random.default_rng(seed)
+    original_data = load_dataset(
+        dataset.repo_id,
+        token=dataset.token,
+        split="+".join(dataset.split),
+    )

     num_unparsable_responses = 0
     category_dataset_pred_src = []
     category_rand_pred_src = []
     is_secondhand_pred_src = []
     is_secondhand_rand_pred_src = []
     brand_pred_src = []
+
     all_possible_brands = set()
-    error_messages = []

-    for elem in chunk_data:
+    for elem in model_output:
         idx = elem["qsl_idx"]
         response = bytes.fromhex(elem["data"]).decode("utf-8")
-        ground_truth_item = dataset[idx]
+        ground_truth_item = original_data[idx]
         all_possible_brands.add(ground_truth_item["ground_truth_brand"])
         try:
             pred_item = ProductMetadata.model_validate_json(response)
@@ -227,15 +245,14 @@ def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
                     ),
                 ),
                 brand=_PRED_BRAND_PAD,
-                is_secondhand=local_rng.choice(
-                    [True, False], size=1).tolist()[0],
+                is_secondhand=rng.choice([True, False], size=1).tolist()[0],
             )
-            error_messages.append(
-                (
-                    f"Response\n{response}\n(for the sample at index {idx})"
-                    f"cannot be validated against"
-                    f" the expected schema. Overwriting this response into\n{pred_item}\n",
-                ),
+            logger.error(
+                "Response\n{}\n(for the sample at index {}) cannot be validated against"
+                " the expected schema. Overwriting this response into\n{}\n",
+                response,
+                idx,
+                pred_item,
             )
         category_dataset_pred_src.append(
             (pred_item.category, ground_truth_item["ground_truth_category"]),
@@ -251,118 +268,35 @@ def _process_chunk(args: tuple[list[dict], int]) -> dict[str, any]:
         )
         # random category selection
         # Uniform distribution is the default
-        rand_cat = local_rng.choice(
+        rand_cat = rng.choice(
             ground_truth_item["potential_product_categories"])
         category_rand_pred_src.append(
             (rand_cat, ground_truth_item["ground_truth_category"]),
         )
         # random is_secondhand selection
-        rand_is_secondhand = local_rng.choice([True, False])
+        rand_is_secondhand = rng.choice([True, False])
         is_secondhand_rand_pred_src.append(
             (rand_is_secondhand,
              ground_truth_item["ground_truth_is_secondhand"]),
         )

-    return {
-        "num_unparsable_responses": num_unparsable_responses,
-        "error_messages": error_messages,
-        "category_dataset_pred_src": category_dataset_pred_src,
-        "category_rand_pred_src": category_rand_pred_src,
-        "is_secondhand_pred_src": is_secondhand_pred_src,
-        "is_secondhand_rand_pred_src": is_secondhand_rand_pred_src,
-        "brand_pred_src": brand_pred_src,
-        "all_possible_brands": list(all_possible_brands),
-    }
-
-
-def run_evaluation(random_seed: int, filename: FilePath,
-                   dataset: DatasetCLI) -> None:
-    """Main function to run the evaluation."""
-    master_rng = np.random.default_rng(seed=random_seed)
-    with Path.open(filename) as f:
-        model_output = json.load(f)
-
-    original_data = load_dataset(
-        dataset.repo_id,
-        token=dataset.token,
-        split="+".join(dataset.split),
-    )
-
-    # get number of available CPU and get chunk size
-    cpu_count = min(os.cpu_count() or 1, _MAX_JOBS)
-    chunk_size = max(len(model_output) // cpu_count, 1)
-    # Create chunks
-    output_chunks = [
-        model_output[i : i + chunk_size]
-        for i in range(0, len(model_output), chunk_size)
-    ]
-
-    # Generate Seeds
-    # One seed per chunk to ensure reproducibility.
-    # The master_rng generates these,
-    # so the whole run is deterministic based on `random_seed`.
-    chunk_seeds = master_rng.integers(0, 2**32, size=len(output_chunks))
-
-    # Zip them: Each task is ([model_out_1, ...], 12345)
-    tasks = zip(output_chunks, chunk_seeds, strict=False)
-
-    num_unparsable_responses = 0
-    err_messages = []
-    category_dataset_pred_src = []
-    category_rand_pred_src = []
-    is_secondhand_pred_src = []
-    is_secondhand_rand_pred_src = []
-    brand_pred_src = []
-    all_possible_brands = []
-
-    with ProcessPoolExecutor(
-        max_workers=cpu_count,
-        initializer=init_worker,
-        initargs=(original_data,),
-    ) as executor:
-        # Execute
-        chunk_results = list(executor.map(_process_chunk, tasks))
-
-    for chunk in chunk_results:
-        num_unparsable_responses += chunk["num_unparsable_responses"]
-        err_messages.extend(chunk["error_messages"])
-        category_dataset_pred_src.extend(chunk["category_dataset_pred_src"])
-        category_rand_pred_src.extend(chunk["category_rand_pred_src"])
-        is_secondhand_pred_src.extend(chunk["is_secondhand_pred_src"])
-        is_secondhand_rand_pred_src.extend(
-            chunk["is_secondhand_rand_pred_src"])
-        brand_pred_src.extend(chunk["brand_pred_src"])
-        all_possible_brands.extend(chunk["all_possible_brands"])
-
-    for err in err_messages:
-        logger.error("{}", err)
-
     category_f1_score = calculate_hierarchical_f1(category_dataset_pred_src)
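+    # hiclass-based hierarchical F1, reported alongside the in-house metric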
+    hiclass_f1_score = calculate_hiclass_f1(category_dataset_pred_src)
     is_secondhand_f1_score = calculate_secondhand_f1(is_secondhand_pred_src)
     brand_score = calculate_brand_f1_score(brand_pred_src)

     rand_cat_f1_score = calculate_hierarchical_f1(category_rand_pred_src)
-
+    rand_hiclass_f1_score = calculate_hiclass_f1(category_rand_pred_src)
     rand_is_seconhand_f1_score = calculate_secondhand_f1(
         is_secondhand_rand_pred_src)
-
-    all_brands_list = list(set(all_possible_brands))
-    random_brand_predictions = master_rng.choice(
-        all_brands_list,
-        size=len(model_output))
-
-    args_list = (
-        (pred, elem, original_data)
-        for pred, elem in zip(random_brand_predictions, model_output, strict=False)
-    )
-
-    with ProcessPoolExecutor() as executor:
-        rand_brand_data = list(executor.map(_process_chunk_rnd_brand,
-                                            args_list,
-                                            chunksize=chunk_size))
-
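+    # Random-brand baseline: draw a uniform random brand for every item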
     rand_brand_score = calculate_brand_f1_score(
-        rand_brand_data,
+        [
+            (
+                rng.choice(list(all_possible_brands)),
+                original_data[elem["qsl_idx"]]["ground_truth_brand"],
+            )
+            for elem in model_output
+        ],
     )

     logger.info(
@@ -373,22 +307,25 @@ def run_evaluation(random_seed: int, filename: FilePath,
                 [
                     "From accuracy file",
                     category_f1_score,
+                    hiclass_f1_score,
                     brand_score,
                     is_secondhand_f1_score,
                 ],
                 [
                     "Random selection",
                     rand_cat_f1_score,
+                    rand_hiclass_f1_score,
                     rand_brand_score,
                     rand_is_seconhand_f1_score,
                 ],
             ],
             headers=[
                 "Results",
                 "Category hierarchical F1 Score",
+                "Category HiClass F1 Score",
                 "Brand F1 Score",
                 "Is_secondhand F1 Score",
             ],
             tablefmt="fancy_grid",
         ),
-    )
+    )