Skip to content

Commit 7378c1d

Browse files
committed
make the metrics work for ft as well
1 parent c3592a1 commit 7378c1d

1 file changed

Lines changed: 8 additions & 12 deletions

File tree

scripts/get_max_eval_metrics_from_wandb.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from olmoearth_pretrain.train.callbacks.evaluator_callback import EvalMode
1717

1818
WANDB_ENTITY = "eai-ai2"
19-
METRICS = list(EVAL_TASKS.keys())
2019

2120
# Dataset partitions to consider (excluding default)
2221
PARTITIONS = [
@@ -375,21 +374,18 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
375374
help="Aggregate metrics per dataset partition instead of grouping by '_step'",
376375
)
377376
parser.add_argument(
378-
"--get_test_metrics",
377+
"--finetune",
379378
action="store_true",
380-
help="Report test metrics based on the configuration of the validation results with the highest score",
379+
help="Use finetune evaluation tasks when determining metrics",
381380
)
382381
parser.add_argument(
383-
"--finetune",
382+
"--get_test_metrics",
384383
action="store_true",
385-
help="Use finetune evaluation tasks when determining metrics",
384+
help="Report test metrics based on the configuration of the validation results with the highest score",
386385
)
387386

388387
args = parser.parse_args()
389-
390-
global METRICS
391-
selected_tasks = FT_EVAL_TASKS if args.finetune else EVAL_TASKS
392-
METRICS = list(selected_tasks.keys())
388+
metrics = list(FT_EVAL_TASKS.keys()) if args.finetune else list(EVAL_TASKS.keys())
393389

394390
if args.per_partition:
395391
if not args.run_prefix:
@@ -404,7 +400,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
404400
for partition in PARTITIONS:
405401
if partition in partition_metrics:
406402
print(f"\n{partition}:")
407-
for metric in METRICS:
403+
for metric in metrics:
408404
# Try original name
409405
key = f"eval/{metric}"
410406
val = partition_metrics[partition].get(key)
@@ -445,7 +441,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
445441
print("\nFinal Results:")
446442
for group_name, metrics in group_metrics.items():
447443
print(f"\n{group_name}:")
448-
for metric in METRICS:
444+
for metric in metrics:
449445
try:
450446
k = f"eval/{metric}"
451447
print(f" {metric}: {metrics[k]}")
@@ -460,7 +456,7 @@ def save_metrics_to_csv(metrics_dict: dict[str, dict[str, float]], filename: str
460456
print("\nFinal Test Results:")
461457
for group_name, metrics in group_test_metrics.items():
462458
print(f"\n{group_name}:")
463-
for metric in METRICS:
459+
for metric in metrics:
464460
try:
465461
k = f"eval/test/{metric}"
466462
print(f" {metric}: {metrics[k]}")

0 commit comments

Comments (0)