diff --git a/scripts/tools/get_max_eval_metrics_from_wandb.py b/scripts/tools/get_max_eval_metrics_from_wandb.py index 4da96b4dd..24b4400e0 100644 --- a/scripts/tools/get_max_eval_metrics_from_wandb.py +++ b/scripts/tools/get_max_eval_metrics_from_wandb.py @@ -32,7 +32,14 @@ def get_run_group_name(run_name: str) -> str: """Extracts the group name from a run name, e.g., 'my_experiment_step_100' -> 'my_experiment'.""" # just split on _step and take the first part - return run_name.split("_step")[0] + if "_step" in run_name: + return run_name.split("_step")[0] + elif "_dataset" in run_name: + return run_name.split("_dataset")[0] + elif "_pre_trained" in run_name: + return run_name.split("_pre_trained")[0] + else: + raise ValueError(f"unexpected run name {run_name}") def get_run_groups( @@ -82,7 +89,9 @@ def group_runs_by_run_prefix_and_step( for run in api.runs(wandb_path, lazy=False): if run_prefix and not run.name.startswith(run_prefix): continue - group_name = get_run_group_name(run.name) if "step" in run.name else run_prefix + group_name = get_run_group_name( + run.name + ) # if "step" in run.name else run_prefix grouped_runs[group_name].append(run) print(f"Found run {run.name} ({run.id}) -> group: {group_name}") return grouped_runs @@ -128,6 +137,17 @@ def _get_corresponding_test_key(key: str) -> str: return key.replace("eval/", "eval/test/") +def _normalize_eval_key(key: str) -> str: + """Normalize eval_other/ keys back to eval/ keys for consistent CSV columns. + + eval_other/{task}/{metric} -> eval/{task}/{metric} + eval_other/test/{task}/{metric} -> eval/test/{task}/{metric} + """ + if key.startswith("eval_other/"): + return key.replace("eval_other/", "eval/", 1) + return key + + def get_max_metrics_grouped( grouped_runs: dict[str, list[wandb.Run]], get_test_metrics: bool = False, @@ -148,24 +168,53 @@ def get_max_metrics_grouped( for run in runs: for key, value in run.summary.items(): # TODO: Make these metrics names constants - if not key.startswith("eval/"): + # Accept both eval/ and eval_other/ keys + if not (key.startswith("eval/") or key.startswith("eval_other/")): continue - if key.startswith("eval/test/"): - print( - f"Skipping test metric {key} for run {run.name} because it is a test metric" - ) - # DO NOT select on test metrics + # Skip test metrics (in both namespaces) + if key.startswith("eval/test/") or key.startswith("eval_other/test/"): continue - # Ensure the run has test metrics - if run.summary.get(_get_corresponding_test_key(key), None) is None: - # DOn't select top val metrics if there is no corresponding test metric - print( - f"Skipping metric {key} for run {run.name} because it has no corresponding test metric" + + # Normalize eval_other/ keys to eval/ for consistent column names + normalized_key = _normalize_eval_key(key) + + # For post-PR#504 runs, eval/{task} contains the primary metric value + # (e.g., micro_f1 for mados). Also store it under eval/{task}/{primary_metric} + # so it lands in the correct sub-metric column alongside pre-PR#504 data. + parts = normalized_key.split("/") + task_name = parts[1] + additional_key = None + if len(parts) == 2: + # This is a primary-only key like eval/{task} + task_config_for_primary = ( + run.config.get("trainer", {}) + .get("callbacks", {}) + .get("downstream_evaluator", {}) + .get("tasks", {}) + .get(task_name, {}) ) + pm = task_config_for_primary.get("primary_metric", None) + if pm is not None: + # Also store as eval/{task}/{primary_metric} for consistent sub-metric columns. + # The config stores the enum name (e.g. "MICRO_F1") but metric + # keys use the enum value (e.g. "micro_f1"), so lowercase it. + additional_key = f"{normalized_key}/{pm.lower()}" + + # Ensure the run has test metrics (check both namespaces). + # For post-PR#504 runs, the primary test metric is at eval/test/{task} + # while sub-metrics are at eval_other/test/{task}/{metric}. + test_key_primary = f"eval/test/{task_name}" + test_key = _get_corresponding_test_key(normalized_key) + has_test = ( + run.summary.get(test_key) is not None + or run.summary.get(test_key.replace("eval/", "eval_other/", 1)) + is not None + or run.summary.get(test_key_primary) is not None + ) + if not has_test: continue - # If for the given metric, it is a linear probe task skip if it was not done with early stop linear porbing - task_name = key.split("/")[1] + # If for the given metric, it is a linear probe task skip if it was not done with early stop linear probing task_config = run.config["trainer"]["callbacks"][ "downstream_evaluator" ]["tasks"][task_name] @@ -182,19 +231,27 @@ def get_max_metrics_grouped( "select_final_test_miou_based_on_epoch_of_max_val_miou", False ), ) - if is_linear_probe_task and not is_select_best_by_primary_metric: + if ( + is_linear_probe_task + and get_test_metrics + and not is_select_best_by_primary_metric + ): print( - f"Skipping metric {key} for run {run.name} because it is a linear probe task but not done with early stop linear probing" + f"Skipping metric {normalized_key} for run {run.name} because it is a linear probe task but not done with early stop linear probing" ) continue - print( - f"Selecting metric {key} for run {run.name} because it matches criteria" - ) - prev_max_val = metrics.get(key, float("-inf")) - metrics[key] = max(prev_max_val, value) + prev_max_val = metrics.get(normalized_key, float("-inf")) + metrics[normalized_key] = max(prev_max_val, value) if value > prev_max_val: - max_runs_per_metric[key] = run + max_runs_per_metric[normalized_key] = run + + # Also record under the explicit sub-metric key + if additional_key is not None: + prev_max_val = metrics.get(additional_key, float("-inf")) + metrics[additional_key] = max(prev_max_val, value) + if value > prev_max_val: + max_runs_per_metric[additional_key] = run group_metrics[group_name] = metrics group_max_runs_per_metric[group_name] = max_runs_per_metric @@ -207,8 +264,19 @@ def get_max_metrics_grouped( for group_name, max_runs_per_metric in group_max_runs_per_metric.items(): test_metrics = {} for metric, run in max_runs_per_metric.items(): + # metric is already normalized to eval/ namespace (e.g. eval/mados/micro_f1) test_metric_key = metric.replace("eval/", "eval/test/") + # Check both eval/ and eval_other/ namespaces for test metrics. + # For post-PR#504 runs with remapped primary keys (eval/{task}/{primary_metric}), + # the actual test value may be at eval/test/{task} (the primary). value = run.summary.get(test_metric_key, None) + if value is None: + alt_key = test_metric_key.replace("eval/", "eval_other/", 1) + value = run.summary.get(alt_key, None) + if value is None: + # Try the primary test key (eval/test/{task}) + task_name = metric.split("/")[1] + value = run.summary.get(f"eval/test/{task_name}", None) if value is None: print( f"No test metric found for run {run.name} for metric {metric}" @@ -271,10 +339,11 @@ def get_max_metrics_per_partition( for run_id in run_ids: run = api.run(f"{WANDB_ENTITY}/{project_name}/{run_id}") for key, value in run.summary.items(): - if not key.startswith("eval/"): + if not (key.startswith("eval/") or key.startswith("eval_other/")): continue - partition_max_metrics[key] = max( - partition_max_metrics.get(key, value), value + normalized_key = _normalize_eval_key(key) + partition_max_metrics[normalized_key] = max( + partition_max_metrics.get(normalized_key, value), value ) partition_metrics[partition] = partition_max_metrics @@ -312,9 +381,10 @@ def get_max_metrics(project_name: str, run_prefix: str) -> dict[str, float]: for run_id in run_ids: run = api.run(f"{WANDB_ENTITY}/{project_name}/{run_id}") for key, value in run.summary.items(): - if not key.startswith("eval/"): + if not (key.startswith("eval/") or key.startswith("eval_other/")): continue - metrics[key] = max(metrics.get(key, value), value) + normalized_key = _normalize_eval_key(key) + metrics[normalized_key] = max(metrics.get(normalized_key, value), value) return metrics @@ -445,14 +515,19 @@ def serialize_max_settings_per_group( # Try original name key = f"eval/{metric}" val = partition_metrics[partition].get(key) + name_for_print = metric # Fallback with underscore variant if val is None: metric_alt = metric.replace("-", "_") key_alt = f"eval/{metric_alt}" val = partition_metrics[partition].get(key_alt) name_for_print = metric_alt if val is not None else metric - else: - name_for_print = metric + # also try the segmentation suffixes + if val is None: + metric_alt = f"{metric}/miou" + key_alt = f"eval/{metric_alt}" + val = partition_metrics[partition].get(key_alt) + name_for_print = metric_alt if val is not None else metric if val is None: print(f" {metric}: not found") @@ -482,35 +557,45 @@ def serialize_max_settings_per_group( serialize_max_settings_per_group( args.json_filename, group_max_runs_per_metric ) + + def _print_task_metrics( + metrics: dict[str, float], prefix: str, task_name: str + ) -> None: + """Print all sub-metrics for a task.""" + task_name_alt = task_name.replace("-", "_") + task_prefix = f"{prefix}/{task_name}/" + task_prefix_alt = f"{prefix}/{task_name_alt}/" + sub_metrics = { + k: v + for k, v in metrics.items() + if (k.startswith(task_prefix) or k.startswith(task_prefix_alt)) + and "/f1_class_" not in k # skip per-class f1 (too verbose) + } + if sub_metrics: + for k in sorted(sub_metrics): + # Extract just the metric name after the task + metric_name = k.split("/")[-1] + print(f" {task_name}/{metric_name}: {sub_metrics[k]}") + else: + # Fall back to eval/{task} (pre-PR#504 runs without sub-metrics) + for name in (task_name, task_name_alt): + k = f"{prefix}/{name}" + if k in metrics: + print(f" {name}: {metrics[k]}") + return + print(f" {task_name}: not found") + print("\nFinal Results:") for group_name, metrics in group_metrics.items(): print(f"\n{group_name}:") for metric in all_metrics: - try: - k = f"eval/{metric}" - print(f" {metric}: {metrics[k]}") - except KeyError: - try: - metric = metric.replace("-", "_") - k = f"eval/{metric}" - print(f" {metric}: {metrics[k]}") - except KeyError: - print(f" {metric}: not found") + _print_task_metrics(metrics, "eval", metric) if args.get_test_metrics: print("\nFinal Test Results:") for group_name, metrics in group_test_metrics.items(): print(f"\n{group_name}:") for metric in all_metrics: - try: - k = f"eval/test/{metric}" - print(f" {metric}: {metrics[k]}") - except KeyError: - try: - metric = metric.replace("-", "_") - k = f"eval/test/{metric}" - print(f" {metric}: {metrics[k]}") - except KeyError: - print(f" {metric}: not found") + _print_task_metrics(metrics, "eval/test", metric) # Save to CSV if args.output: