 from olmoearth_pretrain.train.callbacks.evaluator_callback import EvalMode

 WANDB_ENTITY = "eai-ai2"
-METRICS = list(EVAL_TASKS.keys())

 # Dataset partitions to consider (excluding default)
 PARTITIONS = [
@@ -375,21 +374,18 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
         help="Aggregate metrics per dataset partition instead of grouping by '_step'",
     )
     parser.add_argument(
-        "--get_test_metrics",
+        "--finetune",
         action="store_true",
-        help="Report test metrics based on the configuration of the validation results witht the highest score",
+        help="Use finetune evaluation tasks when determining metrics",
     )
     parser.add_argument(
-        "--finetune",
+        "--get_test_metrics",
         action="store_true",
-        help="Use finetune evaluation tasks when determining metrics",
+        help="Report test metrics based on the configuration with the highest validation score",
     )

     args = parser.parse_args()
-
-    global METRICS
-    selected_tasks = FT_EVAL_TASKS if args.finetune else EVAL_TASKS
-    METRICS = list(selected_tasks.keys())
+    metric_names = list(FT_EVAL_TASKS.keys()) if args.finetune else list(EVAL_TASKS.keys())

     if args.per_partition:
         if not args.run_prefix:
@@ -404,7 +400,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
         for partition in PARTITIONS:
             if partition in partition_metrics:
                 print(f"\n{partition}:")
-                for metric in METRICS:
+                for metric in metric_names:
                     # Try original name
                     key = f"eval/{metric}"
                     val = partition_metrics[partition].get(key)
@@ -445,7 +441,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
         print("\nFinal Results:")
         for group_name, metrics in group_metrics.items():
             print(f"\n{group_name}:")
-            for metric in METRICS:
+            for metric in metric_names:
                 try:
                     k = f"eval/{metric}"
                     print(f"  {metric}: {metrics[k]}")
@@ -460,7 +456,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
         print("\nFinal Test Results:")
         for group_name, metrics in group_test_metrics.items():
             print(f"\n{group_name}:")
-            for metric in METRICS:
+            for metric in metric_names:
                 try:
                     k = f"eval/test/{metric}"
                     print(f"  {metric}: {metrics[k]}")
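
For reference, a minimal standalone sketch of the pattern the new code relies on: a locally scoped metric_names list (replacing the old `global METRICS` mutation) whose entries are looked up against each group's result dict. The task registries and result values below are stubbed placeholders, not the repo's actual FT_EVAL_TASKS / EVAL_TASKS contents; note the names list must not be called `metrics`, since that name is rebound to the per-group dict inside the loop.

import argparse

# Stub registries; the real EVAL_TASKS / FT_EVAL_TASKS are defined elsewhere in the repo.
EVAL_TASKS = {"task_a": None, "task_b": None}
FT_EVAL_TASKS = {"task_a_ft": None}

parser = argparse.ArgumentParser()
parser.add_argument("--finetune", action="store_true")
args = parser.parse_args()

# Local selection instead of mutating a module-level global.
metric_names = list(FT_EVAL_TASKS.keys()) if args.finetune else list(EVAL_TASKS.keys())

# Stub results keyed the way the script expects ("eval/<task>").
group_metrics = {"run-1": {"eval/task_a": 0.91, "eval/task_b": 0.88}}

for group_name, metrics in group_metrics.items():  # `metrics` is the per-group dict here,
    print(f"\n{group_name}:")                      # so the names list needs a distinct name
    for metric in metric_names:
        key = f"eval/{metric}"
        if key in metrics:
            print(f"  {metric}: {metrics[key]}")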