From 722ad3a2441600f85eded9533d54350842a96c24 Mon Sep 17 00:00:00 2001
From: MiguelAFH
Date: Tue, 12 Aug 2025 20:08:01 -0700
Subject: [PATCH] feat: added support for accuracy metrics on yes/no scenarios

---
 data/accuracy_metrics.csv               |  28 +++
 scripts/compute_accuracy_metrics.sh     |  15 ++
 src/medhelm/compute_accuracy_metrics.py | 259 ++++++++++++++++++++++++
 3 files changed, 302 insertions(+)
 create mode 100644 data/accuracy_metrics.csv
 create mode 100644 scripts/compute_accuracy_metrics.sh
 create mode 100644 src/medhelm/compute_accuracy_metrics.py

diff --git a/data/accuracy_metrics.csv b/data/accuracy_metrics.csv
new file mode 100644
index 0000000..db4f422
--- /dev/null
+++ b/data/accuracy_metrics.csv
@@ -0,0 +1,28 @@
+scenario,model,precision,recall,f1_score,TP,FP,FN,TN,total
+ehrshot,deepseek_r1,0.08460304731355253,0.7902621722846442,0.15284317276349152,211,2283,56,450,3000
+ehrshot,o3_mini_2025_01_31,0.09447236180904522,0.352059925093633,0.14896988906497624,94,901,173,1832,3000
+n2c2_ct_matching,deepseek_r1,0.7962962962962963,0.8686868686868687,0.8309178743961353,86,22,13,137,258
+n2c2_ct_matching,o3_mini_2025_01_31,0.839622641509434,0.898989898989899,0.8682926829268292,89,17,10,142,258
+ehrshot,claude_3_5_sonnet_20241022,0.07747196738022426,0.5692883895131086,0.1363840287124271,152,1810,115,923,3000
+ehrshot,claude_3_7_sonnet_20250219,0.07441574415744158,0.45318352059925093,0.1278394083465399,121,1505,146,1228,3000
+ehrshot,gemini_1.5_pro_001,0.09145129224652088,0.5168539325842697,0.1554054054054054,138,1371,129,1362,3000
+ehrshot,gemini_2.0_flash_001,0.0967741935483871,0.550561797752809,0.16461366181410972,147,1372,120,1361,3000
+ehrshot,llama_3.3_70b_instruct,0.12121212121212122,0.4794007490636704,0.19349962207105068,128,928,139,1805,3000
+ehrshot,gpt_4o_2024_05_13,0.07832699619771863,0.3857677902621723,0.1302149178255373,103,1212,164,1521,3000
+ehrshot,gpt_4o_mini_2024_07_18,0.11764705882352941,0.00749063670411985,0.014084507042253521,2,15,265,2718,3000
+n2c2_ct_matching,claude_3_5_sonnet_20241022,0.7592592592592593,0.8282828282828283,0.7922705314009661,82,26,17,133,258
+n2c2_ct_matching,claude_3_7_sonnet_20250219,0.7217391304347827,0.8383838383838383,0.7757009345794392,83,32,16,127,258
+n2c2_ct_matching,gemini_1.5_pro_001,0.6771653543307087,0.8686868686868687,0.7610619469026549,86,41,13,118,258
+n2c2_ct_matching,gemini_2.0_flash_001,0.5792682926829268,0.9595959595959596,0.7224334600760455,95,69,4,90,258
+n2c2_ct_matching,llama_3.3_70b_instruct,0.8072289156626506,0.6767676767676768,0.7362637362637363,67,16,32,143,258
+n2c2_ct_matching,gpt_4o_2024_05_13,0.7570093457943925,0.8181818181818182,0.7864077669902914,81,26,18,133,258
+n2c2_ct_matching,gpt_4o_mini_2024_07_18,0.656934306569343,0.9090909090909091,0.7627118644067796,90,47,9,112,258
+race_based_med,deepseek_r1,0.9487179487179487,0.8809523809523809,0.9135802469135802,74,4,10,79,167
+race_based_med,o3_mini_2025_01_31,0.9315068493150684,0.8095238095238095,0.8662420382165604,68,5,16,78,167
+race_based_med,claude_3_5_sonnet_20241022,0.8484848484848485,0.6666666666666666,0.7466666666666666,56,10,28,73,167
+race_based_med,claude_3_7_sonnet_20250219,0.9655172413793104,0.3333333333333333,0.495575221238938,28,1,56,82,167
+race_based_med,gemini_1.5_pro_001,0.9821428571428571,0.6547619047619048,0.7857142857142857,55,1,29,82,167
+race_based_med,gemini_2.0_flash_001,0.9054054054054054,0.7976190476190477,0.8481012658227848,67,7,17,76,167
+race_based_med,llama_3.3_70b_instruct,1.0,0.14285714285714285,0.25,12,0,72,83,167
+race_based_med,gpt_4o_2024_05_13,0.9036144578313253,0.8928571428571429,0.8982035928143712,75,8,9,75,167
+race_based_med,gpt_4o_mini_2024_07_18,0.9411764705882353,0.5714285714285714,0.7111111111111111,48,3,36,80,167
diff --git a/scripts/compute_accuracy_metrics.sh b/scripts/compute_accuracy_metrics.sh
new file mode 100644
index 0000000..be06520
--- /dev/null
+++ b/scripts/compute_accuracy_metrics.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+SCRIPT_DIR="$(dirname "$(realpath "$0")")"
+cd "$SCRIPT_DIR" || { echo "Failed to change directory"; exit 1; }
+
+# Construct the log file name
+DATE=$(date '+%Y-%m-%d_%H-%M-%S')
+LOG_FILE="../logs/compute_accuracy_metrics_$DATE.log"
+mkdir -p ../logs # Ensure the logs directory exists
+exec > >(tee -a "$LOG_FILE") 2>&1
+
+python3 ../src/medhelm/compute_accuracy_metrics.py \
+    -b ../../rebuttal/benchmark_output \
+    -o ../data/accuracy_metrics.csv
+
diff --git a/src/medhelm/compute_accuracy_metrics.py b/src/medhelm/compute_accuracy_metrics.py
new file mode 100644
index 0000000..3644b5b
--- /dev/null
+++ b/src/medhelm/compute_accuracy_metrics.py
@@ -0,0 +1,259 @@
+import argparse
+import csv
+import os
+import json
+import pandas as pd
+
+from typing import Dict, List, Any
+from tqdm import tqdm
+
+from helm.benchmark.run_spec import RunSpec
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.scenarios.scenario import Instance
+from helm.common.codec import from_json
+from helm.common.hierarchical_logger import hlog
+from helm.benchmark.metrics.metric import PerInstanceStats
+
+from medhelm.utils.constants import (
+    BENCHMARK_NAME_MAPPING
+)
+
+PER_INSTANCE_STATS_FILE_NAME = "per_instance_stats.json"
+SCENARIO_FILE_NAME = "scenario.json"
+SCENARIO_STATE_FILE_NAME = "scenario_state.json"
+RUN_SPEC_FILE_NAME = "run_spec.json"
+INSTANCES_FILE_NAME = "instances.json"
+TEST = "test"
+MAIN_METRIC_RANGE = [1, 5]
+
+BENCHMARKS = [  # only binary yes/no scenarios are supported here
+    "EHRSHOT", # yes, no
+    # "ADHD-Behavior",
+    # "ADHD-MedEffects",
+    # "MedConfInfo"
+    # "ProxySender" # This has 3 possible answers,
+    # "PrivacyDetection",
+    # "PubMedQA" # This has 3 possible answers,
+    # "BMT-Status",
+    "RaceBias", # yes, no
+    "N2C2", # yes, no
+    # "HospiceReferral",
+    # "ClinicReferral",
+    # "CDI-QA",
+    # "ENT-Referral" # This has 3 possible answers
+]
+
+
+def read_per_instance_stats(per_instance_stats_path: str) -> List[PerInstanceStats]:
+    if not os.path.exists(per_instance_stats_path):
+        raise ValueError(f"Could not load [PerInstanceStats] from {per_instance_stats_path}")
+    with open(per_instance_stats_path) as f:
+        return from_json(f.read(), List[PerInstanceStats])
+
+
+def read_run_spec(run_spec_path: str) -> RunSpec:
+    if not os.path.exists(run_spec_path):
+        raise ValueError(f"Could not load RunSpec from {run_spec_path}")
+    with open(run_spec_path) as f:
+        return from_json(f.read(), RunSpec)
+
+
+def read_scenario(scenario_path: str) -> Dict[str, Any]:
+    with open(scenario_path) as scenario_file:
+        scenario = json.load(scenario_file)
+    return scenario
+
+
+def read_scenario_state(scenario_state_path: str) -> ScenarioState:
+    if not os.path.exists(scenario_state_path):
+        raise ValueError(f"Could not load ScenarioState from {scenario_state_path}")
+    with open(scenario_state_path) as f:
+        return from_json(f.read(), ScenarioState)
+
+def read_instances(instance_state_path: str) -> List[Instance]:
+    if not os.path.exists(instance_state_path):
+        raise ValueError(f"Could not load [Instance] from {instance_state_path}")
+    with open(instance_state_path) as f:
+        return from_json(f.read(), List[Instance])
+
+
+def get_run_dirs(runs_path: str) -> List[str]:
+    suite_names = [
+        p for p in os.listdir(runs_path)
+        if os.path.isdir(os.path.join(runs_path, p))
+    ]
+    run_dir_names = []
+    for suite in suite_names:
+        run_suite_path = os.path.join(runs_path, suite)
+        for p in os.listdir(run_suite_path):
+            full_path = os.path.join(run_suite_path, p)
+            if p not in {"eval_cache", "groups"} and os.path.isdir(full_path):
+                run_dir_names.append(os.path.join(suite, p))
+    run_dir_names.sort()
+    return run_dir_names
+
+def get_request_state(scenario_state: ScenarioState, id: str) -> RequestState:
+    request_state = None
+    for state in scenario_state.request_states:
+        if state.instance.id == id:
+            request_state = state
+            break
+    return request_state
+
+def compute_labels(
+    per_instance_stats_list: List[PerInstanceStats],
+    instances: List[Instance],
+    scenario_name: str,
+    model: str,
+    subgroup: str
+) -> List[Dict[str, Any]]:
+    stats = []  # one record per instance: label, derived prediction, provenance
+    for i, stat in enumerate(per_instance_stats_list):
+        instance = instances[i]
+        assert stat.instance_id in instance.id, f"Instance ID mismatch: {stat.instance_id} != {instance.id}"
+
+        # Find the label from instance.references where "correct" is in tags
+        label = None
+        for ref in instance.references:
+            if "correct" in ref.tags:
+                label = ref.output.text
+                break
+
+        # Find the "exact_match" stat
+        exact_match_stat = None
+        for s in stat.stats:
+            if s.name.name == "exact_match":
+                exact_match_stat = s
+                break
+
+        # Derive prediction: if sum == 1.0, prediction == label, else prediction != label
+        if exact_match_stat is not None and label is not None:
+            if exact_match_stat.sum == 1:
+                prediction = label
+            else:
+                if label == "no":  # binary task: the wrong answer is the other label
+                    prediction = "yes"
+                else:
+                    prediction = "no"
+        else:
+            prediction = None  # missing stat or label; dropped from counts later
+
+        stats.append({
+            "instance_id": stat.instance_id,
+            "label": label,
+            "prediction": prediction,
+            "scenario_name": scenario_name,
+            "model": model,
+            "subgroup": subgroup
+        })
+    return stats
+
+
+def get_classification_counts(metrics: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
+    counts = {}  # confusion counts keyed by "scenario,model"
+    for row in metrics:
+        scenario_name = row["scenario_name"]
+        model = row["model"]
+        key = f"{scenario_name},{model}"  # split back apart in get_precision_recall_f1
+        if key not in counts:
+            counts[key] = {
+                "TP": 0,
+                "FP": 0,
+                "FN": 0,
+                "TN": 0,
+            }
+
+        counts[key]["TP"] += int(row["label"] == "yes" and row["prediction"] == "yes")
+        counts[key]["FP"] += int(row["label"] == "no" and row["prediction"] == "yes")
+        counts[key]["FN"] += int(row["label"] == "yes" and row["prediction"] == "no")
+        counts[key]["TN"] += int(row["label"] == "no" and row["prediction"] == "no")
+    return counts
+
+
+def get_precision_recall_f1(counts: Dict[str, Dict[str, int]]) -> List[Dict[str, float]]:
+    results = []
+    for key, metrics in counts.items():
+        TP = metrics['TP']
+        FP = metrics['FP']
+        FN = metrics['FN']
+        TN = metrics['TN']
+
+        precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
+        recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
+        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
+
+        scenario_name, model = key.split(",")  # key built in get_classification_counts
+        results.append({
+            'scenario': scenario_name,
+            'model': model,
+            'precision': precision,
+            'recall': recall,
+            'f1_score': f1,
+            'TP': TP,
+            'FP': FP,
+            'FN': FN,
+            'TN': TN,
+            "total": TP + FP + FN + TN,
+        })
+    return results
+
+
+def main(
+    runs_path: str,
+    output_path: str
+) -> None:
+    metrics = []  # per-instance label/prediction records across all runs
+    run_dir_names = get_run_dirs(runs_path)
+    for run_dir_name in tqdm(run_dir_names, disable=None):
+        scenario_path = os.path.join(runs_path, run_dir_name, SCENARIO_FILE_NAME)
+        scenario = read_scenario(scenario_path)
+        scenario_name = scenario["name"]
+        run_spec_path = os.path.join(runs_path, run_dir_name, RUN_SPEC_FILE_NAME)
+        run_spec = read_run_spec(run_spec_path)
+        model = run_spec.adapter_spec.model.split("/")[-1].replace("-", "_")
+        if scenario_name not in BENCHMARK_NAME_MAPPING:
+            continue
+        if BENCHMARK_NAME_MAPPING[scenario_name] not in BENCHMARKS:
+            continue
+        print(f"Computing metrics for: {BENCHMARK_NAME_MAPPING[scenario_name]}, {model}")
+        per_instance_stats_path = os.path.join(runs_path, run_dir_name, PER_INSTANCE_STATS_FILE_NAME)
+        per_instance_stats_list = read_per_instance_stats(per_instance_stats_path)
+        instances_path = os.path.join(runs_path, run_dir_name, INSTANCES_FILE_NAME)
+        instances = read_instances(instances_path)
+        subgroup = ""
+        if len(run_spec.scenario_spec.args) > 0:
+            subgroup = run_spec.scenario_spec.args  # NOTE(review): dict of scenario args, not a string — confirm intended
+
+        metrics.extend(
+            compute_labels(
+                per_instance_stats_list,
+                instances,
+                scenario_name,
+                model,
+                subgroup
+            )
+        )
+
+    counts = get_classification_counts(metrics)
+    results = get_precision_recall_f1(counts)
+
+    for result in results:
+        print(f"{result['scenario']}, {result['model']}: Precision = {result['precision']:.4f}, Recall = {result['recall']:.4f}, F1 Score = {result['f1_score']:.4f}")
+
+    print(f"Total benchmarks processed: {len(results)}")
+    print("Writing results to output file:", output_path)
+    df = pd.DataFrame(results)
+    df.to_csv(output_path, index=False)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--benchmark_path", "-b", type=str, required=True, help="Path to the directory containing run outputs")
+    parser.add_argument("--output_path", "-o", type=str, required=True, help="Output path for the accuracy metrics csv")
+    args = parser.parse_args()
+
+    main(
+        runs_path=f"{args.benchmark_path}/runs",
+        output_path=args.output_path,
+    )