Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 135 additions & 50 deletions scripts/tools/get_max_eval_metrics_from_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@
def get_run_group_name(run_name: str) -> str:
    """Extract the group name from a run name.

    Run names encode the group as everything before a known suffix marker,
    e.g. 'my_experiment_step_100' -> 'my_experiment'.

    Args:
        run_name: Full W&B run name.

    Returns:
        The prefix of ``run_name`` preceding the first recognized marker,
        checked in priority order: "_step", "_dataset", "_pre_trained".

    Raises:
        ValueError: If the run name contains none of the known markers.
    """
    # NOTE: the stale unconditional `return run_name.split("_step")[0]` left
    # over from the old implementation made the branching below unreachable;
    # it has been removed.
    for marker in ("_step", "_dataset", "_pre_trained"):
        if marker in run_name:
            return run_name.split(marker)[0]
    raise ValueError(f"unexpected run name {run_name}")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made this change because we now have very long run names, which means I need to assign the model a new name during evals; I still want to be able to group all the runs in an eval project.



def get_run_groups(
Expand Down Expand Up @@ -82,7 +89,9 @@ def group_runs_by_run_prefix_and_step(
for run in api.runs(wandb_path, lazy=False):
if run_prefix and not run.name.startswith(run_prefix):
continue
group_name = get_run_group_name(run.name) if "step" in run.name else run_prefix
group_name = get_run_group_name(
run.name
) # if "step" in run.name else run_prefix
grouped_runs[group_name].append(run)
print(f"Found run {run.name} ({run.id}) -> group: {group_name}")
return grouped_runs
Expand Down Expand Up @@ -128,6 +137,17 @@ def _get_corresponding_test_key(key: str) -> str:
return key.replace("eval/", "eval/test/")


def _normalize_eval_key(key: str) -> str:
"""Normalize eval_other/ keys back to eval/ keys for consistent CSV columns.

eval_other/{task}/{metric} -> eval/{task}/{metric}
eval_other/test/{task}/{metric} -> eval/test/{task}/{metric}
"""
if key.startswith("eval_other/"):
return key.replace("eval_other/", "eval/", 1)
return key


def get_max_metrics_grouped(
grouped_runs: dict[str, list[wandb.Run]],
get_test_metrics: bool = False,
Expand All @@ -148,24 +168,53 @@ def get_max_metrics_grouped(
for run in runs:
for key, value in run.summary.items():
# TODO: Make these metrics names constants
if not key.startswith("eval/"):
# Accept both eval/ and eval_other/ keys
if not (key.startswith("eval/") or key.startswith("eval_other/")):
continue
if key.startswith("eval/test/"):
print(
f"Skipping test metric {key} for run {run.name} because it is a test metric"
)
# DO NOT select on test metrics
# Skip test metrics (in both namespaces)
if key.startswith("eval/test/") or key.startswith("eval_other/test/"):
continue
# Ensure the run has test metrics
if run.summary.get(_get_corresponding_test_key(key), None) is None:
# DOn't select top val metrics if there is no corresponding test metric
print(
f"Skipping metric {key} for run {run.name} because it has no corresponding test metric"

# Normalize eval_other/ keys to eval/ for consistent column names
normalized_key = _normalize_eval_key(key)

# For post-PR#504 runs, eval/{task} contains the primary metric value
# (e.g., micro_f1 for mados). Also store it under eval/{task}/{primary_metric}
# so it lands in the correct sub-metric column alongside pre-PR#504 data.
parts = normalized_key.split("/")
task_name = parts[1]
additional_key = None
if len(parts) == 2:
# This is a primary-only key like eval/{task}
task_config_for_primary = (
run.config.get("trainer", {})
.get("callbacks", {})
.get("downstream_evaluator", {})
.get("tasks", {})
.get(task_name, {})
)
pm = task_config_for_primary.get("primary_metric", None)
if pm is not None:
# Also store as eval/{task}/{primary_metric} for consistent sub-metric columns.
# The config stores the enum name (e.g. "MICRO_F1") but metric
# keys use the enum value (e.g. "micro_f1"), so lowercase it.
additional_key = f"{normalized_key}/{pm.lower()}"

# Ensure the run has test metrics (check both namespaces).
# For post-PR#504 runs, the primary test metric is at eval/test/{task}
# while sub-metrics are at eval_other/test/{task}/{metric}.
test_key_primary = f"eval/test/{task_name}"
test_key = _get_corresponding_test_key(normalized_key)
has_test = (
run.summary.get(test_key) is not None
or run.summary.get(test_key.replace("eval/", "eval_other/", 1))
is not None
or run.summary.get(test_key_primary) is not None
)
if not has_test:
continue

# If for the given metric, it is a linear probe task skip if it was not done with early stop linear porbing
task_name = key.split("/")[1]
# If for the given metric, it is a linear probe task skip if it was not done with early stop linear probing
task_config = run.config["trainer"]["callbacks"][
"downstream_evaluator"
]["tasks"][task_name]
Expand All @@ -182,19 +231,27 @@ def get_max_metrics_grouped(
"select_final_test_miou_based_on_epoch_of_max_val_miou", False
),
)
if is_linear_probe_task and not is_select_best_by_primary_metric:
if (
is_linear_probe_task
and get_test_metrics
and not is_select_best_by_primary_metric
):
print(
f"Skipping metric {key} for run {run.name} because it is a linear probe task but not done with early stop linear probing"
f"Skipping metric {normalized_key} for run {run.name} because it is a linear probe task but not done with early stop linear probing"
)
continue
print(
f"Selecting metric {key} for run {run.name} because it matches criteria"
)

prev_max_val = metrics.get(key, float("-inf"))
metrics[key] = max(prev_max_val, value)
prev_max_val = metrics.get(normalized_key, float("-inf"))
metrics[normalized_key] = max(prev_max_val, value)
if value > prev_max_val:
max_runs_per_metric[key] = run
max_runs_per_metric[normalized_key] = run

# Also record under the explicit sub-metric key
if additional_key is not None:
prev_max_val = metrics.get(additional_key, float("-inf"))
metrics[additional_key] = max(prev_max_val, value)
if value > prev_max_val:
max_runs_per_metric[additional_key] = run

group_metrics[group_name] = metrics
group_max_runs_per_metric[group_name] = max_runs_per_metric
Expand All @@ -207,8 +264,19 @@ def get_max_metrics_grouped(
for group_name, max_runs_per_metric in group_max_runs_per_metric.items():
test_metrics = {}
for metric, run in max_runs_per_metric.items():
# metric is already normalized to eval/ namespace (e.g. eval/mados/micro_f1)
test_metric_key = metric.replace("eval/", "eval/test/")
# Check both eval/ and eval_other/ namespaces for test metrics.
# For post-PR#504 runs with remapped primary keys (eval/{task}/{primary_metric}),
# the actual test value may be at eval/test/{task} (the primary).
value = run.summary.get(test_metric_key, None)
if value is None:
alt_key = test_metric_key.replace("eval/", "eval_other/", 1)
value = run.summary.get(alt_key, None)
if value is None:
# Try the primary test key (eval/test/{task})
task_name = metric.split("/")[1]
value = run.summary.get(f"eval/test/{task_name}", None)
if value is None:
print(
f"No test metric found for run {run.name} for metric {metric}"
Expand Down Expand Up @@ -271,10 +339,11 @@ def get_max_metrics_per_partition(
for run_id in run_ids:
run = api.run(f"{WANDB_ENTITY}/{project_name}/{run_id}")
for key, value in run.summary.items():
if not key.startswith("eval/"):
if not (key.startswith("eval/") or key.startswith("eval_other/")):
continue
partition_max_metrics[key] = max(
partition_max_metrics.get(key, value), value
normalized_key = _normalize_eval_key(key)
partition_max_metrics[normalized_key] = max(
partition_max_metrics.get(normalized_key, value), value
)

partition_metrics[partition] = partition_max_metrics
Expand Down Expand Up @@ -312,9 +381,10 @@ def get_max_metrics(project_name: str, run_prefix: str) -> dict[str, float]:
for run_id in run_ids:
run = api.run(f"{WANDB_ENTITY}/{project_name}/{run_id}")
for key, value in run.summary.items():
if not key.startswith("eval/"):
if not (key.startswith("eval/") or key.startswith("eval_other/")):
continue
metrics[key] = max(metrics.get(key, value), value)
normalized_key = _normalize_eval_key(key)
metrics[normalized_key] = max(metrics.get(normalized_key, value), value)
return metrics


Expand Down Expand Up @@ -445,14 +515,19 @@ def serialize_max_settings_per_group(
# Try original name
key = f"eval/{metric}"
val = partition_metrics[partition].get(key)
name_for_print = metric
# Fallback with underscore variant
if val is None:
metric_alt = metric.replace("-", "_")
key_alt = f"eval/{metric_alt}"
val = partition_metrics[partition].get(key_alt)
name_for_print = metric_alt if val is not None else metric
else:
name_for_print = metric
# also try the segmentation suffixes
if val is None:
metric_alt = f"{metric}/miou"
key_alt = f"eval/{metric_alt}"
val = partition_metrics[partition].get(key_alt)
name_for_print = metric_alt if val is not None else metric

if val is None:
print(f" {metric}: not found")
Expand Down Expand Up @@ -482,35 +557,45 @@ def serialize_max_settings_per_group(
serialize_max_settings_per_group(
args.json_filename, group_max_runs_per_metric
)

    def _print_task_metrics(
        metrics: dict[str, float], prefix: str, task_name: str
    ) -> None:
        """Print all sub-metrics for a task.

        Looks up keys of the form ``{prefix}/{task}/{metric}`` in *metrics*
        (also trying the task name with '-' replaced by '_') and prints each
        sub-metric on its own line. If no sub-metric keys exist, falls back
        to the bare ``{prefix}/{task}`` key (pre-PR#504 runs recorded only a
        primary metric), and prints 'not found' when neither form is present.
        """
        # Task names may use '-' in configs but '_' in metric keys.
        task_name_alt = task_name.replace("-", "_")
        task_prefix = f"{prefix}/{task_name}/"
        task_prefix_alt = f"{prefix}/{task_name_alt}/"
        sub_metrics = {
            k: v
            for k, v in metrics.items()
            if (k.startswith(task_prefix) or k.startswith(task_prefix_alt))
            and "/f1_class_" not in k  # skip per-class f1 (too verbose)
        }
        if sub_metrics:
            # Sorted for deterministic, readable output.
            for k in sorted(sub_metrics):
                # Extract just the metric name after the task
                metric_name = k.split("/")[-1]
                print(f"  {task_name}/{metric_name}: {sub_metrics[k]}")
        else:
            # Fall back to eval/{task} (pre-PR#504 runs without sub-metrics)
            for name in (task_name, task_name_alt):
                k = f"{prefix}/{name}"
                if k in metrics:
                    print(f"  {name}: {metrics[k]}")
                    return
            print(f"  {task_name}: not found")

print("\nFinal Results:")
for group_name, metrics in group_metrics.items():
print(f"\n{group_name}:")
for metric in all_metrics:
try:
k = f"eval/{metric}"
print(f" {metric}: {metrics[k]}")
except KeyError:
try:
metric = metric.replace("-", "_")
k = f"eval/{metric}"
print(f" {metric}: {metrics[k]}")
except KeyError:
print(f" {metric}: not found")
_print_task_metrics(metrics, "eval", metric)
if args.get_test_metrics:
print("\nFinal Test Results:")
for group_name, metrics in group_test_metrics.items():
print(f"\n{group_name}:")
for metric in all_metrics:
try:
k = f"eval/test/{metric}"
print(f" {metric}: {metrics[k]}")
except KeyError:
try:
metric = metric.replace("-", "_")
k = f"eval/test/{metric}"
print(f" {metric}: {metrics[k]}")
except KeyError:
print(f" {metric}: not found")
_print_task_metrics(metrics, "eval/test", metric)

# Save to CSV
if args.output:
Expand Down
Loading