     print_writeout,
     run_task_tests,
 )
-from lm_eval.logging_utils import add_env_info, get_git_commit_hash
+from lm_eval.loggers.utils import add_env_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
 from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
 from lm_eval import utils
@@ -509,9 +509,14 @@ def evaluate(
     # aggregate results; run bootstrap CIs
     for task_output in eval_tasks:
         task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
-    results, samples, configs, versions, num_fewshot = consolidate_results(
-        eval_tasks
-    )
+    (
+        results,
+        samples,
+        configs,
+        versions,
+        num_fewshot,
+        higher_is_better,
+    ) = consolidate_results(eval_tasks)

     ### Calculate group metrics ###
     if bool(results):
@@ -522,6 +527,24 @@ def evaluate(
                 # or `task_name: []`.
                 # we only want to operate on groups here.
                 continue
+
+            # collect all higher_is_better values for metrics
+            # in the group's subtasks.
+            # TODO: clean this up; unify with the below metric_list loop?
+            _higher_is_better = {}
+            for task in task_list:
+                for m, h in higher_is_better[task].items():
+                    if m not in _higher_is_better.keys():
+                        _higher_is_better[m] = h
+                    if m in _higher_is_better and _higher_is_better[m] is not None and _higher_is_better[m] != h:
+                        eval_logger.warning(
+                            f"Higher_is_better values for metric {m} in group {group} are not consistent. "
+                            "Defaulting to None."
+                        )
+                        _higher_is_better[m] = None
+            higher_is_better[group] = _higher_is_better
+
+            # collect all metric keys used by a subtask in the group.
             metric_list = list(
                 {
                     key
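
(Note, not part of the diff: a minimal sketch of what the higher_is_better reconciliation added above does for a group. It keeps the first direction seen for each metric and falls back to None when subtasks disagree; the subtask names and values here are invented for illustration.)

higher_is_better = {
    "subtask_a": {"acc": True, "perplexity": False},
    "subtask_b": {"acc": True, "perplexity": True},  # disagrees on perplexity
}

_higher_is_better = {}
for task in ("subtask_a", "subtask_b"):
    for m, h in higher_is_better[task].items():
        if m not in _higher_is_better:
            _higher_is_better[m] = h
        elif _higher_is_better[m] is not None and _higher_is_better[m] != h:
            # conflicting directions across subtasks: unknown at the group level
            _higher_is_better[m] = None

print(_higher_is_better)  # {'acc': True, 'perplexity': None}
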
@@ -534,38 +557,22 @@ def evaluate(
                 stderr = "_stderr,".join(metric.split(","))

                 # gather metrics, sizes, and stderrs from subtasks
-                metrics = [
-                    results[task][metric]
-                    for task in task_list
-                    if metric in results[task]
-                ]  # TODO: copy?
-                stderrs = [
-                    results[task][stderr]
-                    for task in task_list
-                    if stderr in results[task]
-                ]
-                sizes = [
-                    results[task]["samples"]
-                    for task in task_list
-                    if metric in results[task]
-                ]
+                metrics = [results[task][metric] for task in task_list if metric in results[task]]  # TODO: copy?
+                stderrs = [results[task][stderr] for task in task_list if stderr in results[task]]
+                sizes = [results[task]["samples"] for task in task_list if metric in results[task]]

                 # compute group's pooled metric and stderr
-                results[group][metric] = (
-                    lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
-                )
+                results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                 # TODO: calculate grouped metric using aggregation fn
                 if "N/A" in stderrs:
                     results[group][stderr] = "N/A"
                 else:
-                    results[group][stderr] = (
-                        lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
-                    )
+                    results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                 # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                 # To use the old (likely incorrect) variance formula,
                 # comment out the above and uncomment this line:
-                # results[group][stderr] = \
-                #     lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)
+                # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs,
+                #     sizes, metrics=metrics)

                 results[group]["samples"] = sum(sizes)

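
(Note, not part of the diff: the group score above is a size-weighted mean of the subtask scores with a pooled standard error. The exact formulas live in lm_eval.api.metrics' aggregate_subtask_metrics and pooled_sample_stderr; the snippet below is only a rough stand-in with made-up numbers, and the stderr pooling rule shown is an assumption rather than the library's exact implementation.)

import math

metrics = [0.62, 0.71, 0.55]  # per-subtask scores
stderrs = [0.02, 0.03, 0.04]  # per-subtask standard errors
sizes = [400, 250, 150]       # per-subtask sample counts

# size-weighted mean of the subtask scores
group_metric = sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes)

# one plausible pooling rule: pool per-task sample variances, then rescale
# by the total sample count (illustrative only)
pooled_var = sum((n - 1) * se**2 * n for se, n in zip(stderrs, sizes)) / (
    sum(sizes) - len(sizes)
)
group_stderr = math.sqrt(pooled_var / sum(sizes))

print(round(group_metric, 4), round(group_stderr, 5))
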
@@ -578,19 +585,15 @@ def evaluate(
         if len(left_tasks_list) == 0:
             break

-        _task_hierarchy = {
-            k: v for k, v in task_hierarchy.items() if k in left_tasks_list
-        }
+        _task_hierarchy = {k: v for k, v in task_hierarchy.items() if k in left_tasks_list}
         _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)

         results_agg = {**results_agg, **_results_agg}
         groups_agg = {**groups_agg, **_groups_agg}

     for group_name, task_list in task_hierarchy.items():
         if task_list:
-            num_fewshot[group_name] = num_fewshot[
-                task_list[0]
-            ]  # TODO: validate this
+            num_fewshot[group_name] = num_fewshot[task_list[0]]  # TODO: validate this

     results_dict = {
         "results": dict(results_agg.items()),
@@ -599,6 +602,17 @@ def evaluate(
         "configs": dict(sorted(configs.items())),
         "versions": dict(sorted(versions.items())),
         "n-shot": dict(sorted(num_fewshot.items())),
+        "higher_is_better": dict(sorted(higher_is_better.items())),
+        "n-samples": {
+            task_output.task_name: {
+                "original": len(task_output.task.eval_docs),
+                "effective": min(
+                    limit if limit else len(task_output.task.eval_docs),
+                    len(task_output.task.eval_docs),
+                ),
+            }
+            for task_output in eval_tasks
+        },
     }
     if log_samples:
         results_dict["samples"] = dict(samples)
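
(Note, not part of the diff: the new "n-samples" entry records, per task, the original number of evaluation documents and the effective count after an optional limit. A tiny illustration using a hypothetical helper in place of the harness's per-task objects.)

def n_samples_entry(num_eval_docs, limit=None):
    # mirrors the min(limit or total, total) logic added above
    return {
        "original": num_eval_docs,
        "effective": min(limit if limit else num_eval_docs, num_eval_docs),
    }

print(n_samples_entry(1000))             # {'original': 1000, 'effective': 1000}
print(n_samples_entry(1000, limit=100))  # {'original': 1000, 'effective': 100}
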
@@ -608,7 +622,6 @@ def evaluate(
     else:
         return None

-
 def request_caching_arg_to_dict(cache_requests: str) -> dict:
     request_caching_args = {
         "cache_requests": cache_requests in {"true", "refresh"},