Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions data/accuracy_metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
scenario,model,precision,recall,f1_score,TP,FP,FN,TN,total
ehrshot,deepseek_r1,0.08460304731355253,0.7902621722846442,0.15284317276349152,211,2283,56,450,3000
ehrshot,o3_mini_2025_01_31,0.09447236180904522,0.352059925093633,0.14896988906497624,94,901,173,1832,3000
n2c2_ct_matching,deepseek_r1,0.7962962962962963,0.8686868686868687,0.8309178743961353,86,22,13,137,258
n2c2_ct_matching,o3_mini_2025_01_31,0.839622641509434,0.898989898989899,0.8682926829268292,89,17,10,142,258
ehrshot,claude_3_5_sonnet_20241022,0.07747196738022426,0.5692883895131086,0.1363840287124271,152,1810,115,923,3000
ehrshot,claude_3_7_sonnet_20250219,0.07441574415744158,0.45318352059925093,0.1278394083465399,121,1505,146,1228,3000
ehrshot,gemini_1.5_pro_001,0.09145129224652088,0.5168539325842697,0.1554054054054054,138,1371,129,1362,3000
ehrshot,gemini_2.0_flash_001,0.0967741935483871,0.550561797752809,0.16461366181410972,147,1372,120,1361,3000
ehrshot,llama_3.3_70b_instruct,0.12121212121212122,0.4794007490636704,0.19349962207105068,128,928,139,1805,3000
ehrshot,gpt_4o_2024_05_13,0.07832699619771863,0.3857677902621723,0.1302149178255373,103,1212,164,1521,3000
ehrshot,gpt_4o_mini_2024_07_18,0.11764705882352941,0.00749063670411985,0.014084507042253521,2,15,265,2718,3000
n2c2_ct_matching,claude_3_5_sonnet_20241022,0.7592592592592593,0.8282828282828283,0.7922705314009661,82,26,17,133,258
n2c2_ct_matching,claude_3_7_sonnet_20250219,0.7217391304347827,0.8383838383838383,0.7757009345794392,83,32,16,127,258
n2c2_ct_matching,gemini_1.5_pro_001,0.6771653543307087,0.8686868686868687,0.7610619469026549,86,41,13,118,258
n2c2_ct_matching,gemini_2.0_flash_001,0.5792682926829268,0.9595959595959596,0.7224334600760455,95,69,4,90,258
n2c2_ct_matching,llama_3.3_70b_instruct,0.8072289156626506,0.6767676767676768,0.7362637362637363,67,16,32,143,258
n2c2_ct_matching,gpt_4o_2024_05_13,0.7570093457943925,0.8181818181818182,0.7864077669902914,81,26,18,133,258
n2c2_ct_matching,gpt_4o_mini_2024_07_18,0.656934306569343,0.9090909090909091,0.7627118644067796,90,47,9,112,258
race_based_med,deepseek_r1,0.9487179487179487,0.8809523809523809,0.9135802469135802,74,4,10,79,167
race_based_med,o3_mini_2025_01_31,0.9315068493150684,0.8095238095238095,0.8662420382165604,68,5,16,78,167
race_based_med,claude_3_5_sonnet_20241022,0.8484848484848485,0.6666666666666666,0.7466666666666666,56,10,28,73,167
race_based_med,claude_3_7_sonnet_20250219,0.9655172413793104,0.3333333333333333,0.495575221238938,28,1,56,82,167
race_based_med,gemini_1.5_pro_001,0.9821428571428571,0.6547619047619048,0.7857142857142857,55,1,29,82,167
race_based_med,gemini_2.0_flash_001,0.9054054054054054,0.7976190476190477,0.8481012658227848,67,7,17,76,167
race_based_med,llama_3.3_70b_instruct,1.0,0.14285714285714285,0.25,12,0,72,83,167
race_based_med,gpt_4o_2024_05_13,0.9036144578313253,0.8928571428571429,0.8982035928143712,75,8,9,75,167
race_based_med,gpt_4o_mini_2024_07_18,0.9411764705882353,0.5714285714285714,0.7111111111111111,48,3,36,80,167
15 changes: 15 additions & 0 deletions scripts/compute_accuracy_metrics.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash
#
# Run compute_accuracy_metrics.py over the rebuttal benchmark outputs and
# write the resulting CSV, teeing all output to a timestamped log file.

# Resolve the directory containing this script and run from there so the
# relative paths below work regardless of the caller's CWD.
# BUG FIX: the original did `cd "$SCRIPT_DIR/$DIR"` where $DIR was never
# defined, so the target directory depended on an unset variable.
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR" || { echo "Failed to change directory"; exit 1; }

# Construct the log file name
DATE=$(date '+%Y-%m-%d_%H-%M-%S')
LOG_FILE="../logs/compute_accuracy_metrics_$DATE.log"
mkdir -p ../logs # Ensure the logs directory exists
exec > >(tee -a "$LOG_FILE") 2>&1

python3 ../src/medhelm/compute_accuracy_metrics.py \
    -b ../../rebuttal/benchmark_output \
    -o ../data/accuracy_metrics.csv

259 changes: 259 additions & 0 deletions src/medhelm/compute_accuracy_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import argparse
import csv
import os
import json
import pandas as pd

from typing import Dict, List, Any
from tqdm import tqdm

from helm.benchmark.run_spec import RunSpec
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.scenarios.scenario import Instance
from helm.common.codec import from_json
from helm.common.hierarchical_logger import hlog
from helm.benchmark.metrics.metric import PerInstanceStats

from medhelm.utils.constants import (
BENCHMARK_NAME_MAPPING
)

# Standard file names inside a HELM run directory.
PER_INSTANCE_STATS_FILE_NAME = "per_instance_stats.json"
SCENARIO_FILE_NAME = "scenario.json"
SCENARIO_STATE_FILE_NAME = "scenario_state.json"
RUN_SPEC_FILE_NAME = "run_spec.json"
INSTANCES_FILE_NAME = "instances.json"
# NOTE(review): TEST and MAIN_METRIC_RANGE are not referenced anywhere in this
# module — presumably leftovers or used by callers elsewhere; confirm before removing.
TEST = "test"
MAIN_METRIC_RANGE = [1, 5]

# Benchmarks whose answers are strictly binary ("yes"/"no"). Only these are
# processed, because compute_labels() reconstructs a wrong prediction by
# flipping the binary label. Benchmarks with three possible answers are
# deliberately excluded (left commented out).
BENCHMARKS = [
    "EHRSHOT", # yes, no
    # "ADHD-Behavior",
    # "ADHD-MedEffects",
    # "MedConfInfo"
    # "ProxySender" # This has 3 possible answers,
    # "PrivacyDetection",
    # "PubMedQA" # This has 3 possible answers,
    # "BMT-Status",
    "RaceBias", # yes, no
    "N2C2", # yes, no
    # "HospiceReferral",
    # "ClinicReferral",
    # "CDI-QA",
    # "ENT-Referral" # This has 3 possible answers
]


def read_per_instance_stats(per_instance_stats_path: str) -> List[PerInstanceStats]:
    """Deserialize the list of PerInstanceStats stored at *per_instance_stats_path*.

    Raises:
        ValueError: if the file does not exist.
    """
    if not os.path.exists(per_instance_stats_path):
        raise ValueError(f"Could not load [PerInstanceStats] from {per_instance_stats_path}")
    with open(per_instance_stats_path) as stats_file:
        contents = stats_file.read()
    return from_json(contents, List[PerInstanceStats])


def read_run_spec(run_spec_path: str) -> RunSpec:
    """Deserialize the RunSpec stored at *run_spec_path*.

    Raises:
        ValueError: if the file does not exist.
    """
    if not os.path.exists(run_spec_path):
        raise ValueError(f"Could not load RunSpec from {run_spec_path}")
    with open(run_spec_path) as spec_file:
        contents = spec_file.read()
    return from_json(contents, RunSpec)


def read_scenario(scenario_path: str) -> Dict[str, Any]:
    """Load a run's scenario.json and return it as a plain dict."""
    with open(scenario_path) as scenario_file:
        return json.load(scenario_file)


def read_scenario_state(scenario_state_path: str) -> ScenarioState:
    """Deserialize the ScenarioState stored at *scenario_state_path*.

    Raises:
        ValueError: if the file does not exist.
    """
    if not os.path.exists(scenario_state_path):
        raise ValueError(f"Could not load ScenarioState from {scenario_state_path}")
    with open(scenario_state_path) as state_file:
        contents = state_file.read()
    return from_json(contents, ScenarioState)

def read_instances(instance_state_path: str) -> List[Instance]:
    """Deserialize the list of Instance objects stored at *instance_state_path*.

    Raises:
        ValueError: if the file does not exist.
    """
    if not os.path.exists(instance_state_path):
        raise ValueError(f"Could not load [Instance] from {instance_state_path}")
    with open(instance_state_path) as instances_file:
        contents = instances_file.read()
    return from_json(contents, List[Instance])


def get_run_dirs(runs_path: str) -> List[str]:
    """Return sorted "<suite>/<run>" relative paths for every run directory.

    Scans each suite directory directly under *runs_path* and collects its
    run subdirectories, skipping the non-run "eval_cache" and "groups" entries.
    """
    non_run_entries = {"eval_cache", "groups"}
    found: List[str] = []
    for suite in os.listdir(runs_path):
        suite_path = os.path.join(runs_path, suite)
        if not os.path.isdir(suite_path):
            continue
        for entry in os.listdir(suite_path):
            if entry in non_run_entries:
                continue
            if os.path.isdir(os.path.join(suite_path, entry)):
                found.append(os.path.join(suite, entry))
    return sorted(found)

def get_request_state(scenario_state: ScenarioState, id: str) -> RequestState:
    """Return the first request state whose instance has the given id, or None.

    NOTE(review): `id` shadows the builtin; kept to preserve the public
    signature for any keyword callers.
    """
    return next(
        (state for state in scenario_state.request_states if state.instance.id == id),
        None,
    )

def compute_labels(
    per_instance_stats_list: List[PerInstanceStats],
    instances: List[Instance],
    scenario_name: str,
    model: str,
    subgroup: str
) -> List[Dict[str, Any]]:
    """Build one row per instance pairing its gold label with a derived prediction.

    The prediction is reconstructed from the "exact_match" stat: a sum of 1
    means the model answered the gold label; otherwise the prediction is taken
    to be the opposite binary answer. This flip is only valid for strictly
    yes/no benchmarks (see BENCHMARKS).

    Note: the return annotation was fixed — this returns a list of row dicts,
    not Dict[str, float] as previously annotated.

    Raises:
        AssertionError: if stats and instances are not aligned by instance id.
    """
    rows: List[Dict[str, Any]] = []
    for i, stat in enumerate(per_instance_stats_list):
        instance = instances[i]
        # Stats and instances are assumed to be parallel lists; the stat's id
        # must be contained in the instance id (ids may carry suffixes).
        assert stat.instance_id in instance.id, \
            f"Instance ID mismatch: {stat.instance_id} not in {instance.id}"

        # Gold label: the reference tagged "correct".
        label = None
        for ref in instance.references:
            if "correct" in ref.tags:
                label = ref.output.text
                break

        # The recorded "exact_match" stat for this instance, if any.
        exact_match_stat = None
        for s in stat.stats:
            if s.name.name == "exact_match":
                exact_match_stat = s
                break

        # Derive the prediction: exact match => the gold label; otherwise the
        # opposite binary answer. None when either piece is missing.
        if exact_match_stat is None or label is None:
            prediction = None
        elif exact_match_stat.sum == 1:
            prediction = label
        else:
            prediction = "yes" if label == "no" else "no"

        rows.append({
            "instance_id": stat.instance_id,
            "label": label,
            "prediction": prediction,
            "scenario_name": scenario_name,
            "model": model,
            "subgroup": subgroup
        })
    return rows


def get_classification_counts(metrics: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
    """Tally TP/FP/FN/TN per "scenario,model" key from labeled prediction rows.

    Rows whose prediction is neither "yes" nor "no" (e.g. None) contribute to
    no bucket but still create the key's zeroed entry.
    """
    counts: Dict[str, Dict[str, int]] = {}
    for row in metrics:
        key = f"{row['scenario_name']},{row['model']}"
        bucket = counts.setdefault(key, {"TP": 0, "FP": 0, "FN": 0, "TN": 0})
        label = row["label"]
        prediction = row["prediction"]
        if prediction == "yes":
            if label == "yes":
                bucket["TP"] += 1
            elif label == "no":
                bucket["FP"] += 1
        elif prediction == "no":
            if label == "yes":
                bucket["FN"] += 1
            elif label == "no":
                bucket["TN"] += 1
    return counts


def get_precision_recall_f1(counts: Dict[str, Dict[str, int]]) -> List[Dict[str, float]]:
    """Convert per-(scenario,model) confusion counts into precision/recall/F1 rows.

    Each key in *counts* is a "scenario,model" string; each value holds the
    TP/FP/FN/TN tallies. Undefined ratios (zero denominators) become 0.0.
    """
    rows: List[Dict[str, float]] = []
    for key, tallies in counts.items():
        tp = tallies['TP']
        fp = tallies['FP']
        fn = tallies['FN']
        tn = tallies['TN']

        predicted_pos = tp + fp
        actual_pos = tp + fn
        precision = tp / predicted_pos if predicted_pos else 0.0
        recall = tp / actual_pos if actual_pos else 0.0
        pr_sum = precision + recall
        f1 = (2 * precision * recall / pr_sum) if pr_sum else 0.0

        scenario_name, model = key.split(",")
        rows.append({
            'scenario': scenario_name,
            'model': model,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'TP': tp,
            'FP': fp,
            'FN': fn,
            'TN': tn,
            "total": tp + fp + fn + tn,
        })
    return rows


def main(
    runs_path: str,
    output_path: str
) -> None:
    """Compute precision/recall/F1 per (scenario, model) and write a CSV.

    Scans every run directory under *runs_path*, keeps only the binary
    yes/no benchmarks listed in BENCHMARKS, derives per-instance labels and
    predictions, aggregates confusion counts, and writes the metrics to
    *output_path*.
    """
    # BUG FIX: this accumulator was named "metrics" and later shadowed by the
    # results-printing loop variable; renamed for clarity.
    per_instance_rows = []
    run_dir_names = get_run_dirs(runs_path)
    for run_dir_name in tqdm(run_dir_names, disable=None):
        run_path = os.path.join(runs_path, run_dir_name)
        scenario = read_scenario(os.path.join(run_path, SCENARIO_FILE_NAME))
        scenario_name = scenario["name"]
        run_spec = read_run_spec(os.path.join(run_path, RUN_SPEC_FILE_NAME))
        # Normalize e.g. "openai/gpt-4o-2024-05-13" -> "gpt_4o_2024_05_13".
        model = run_spec.adapter_spec.model.split("/")[-1].replace("-", "_")
        # Skip scenarios without a display name or not in the binary allowlist.
        if scenario_name not in BENCHMARK_NAME_MAPPING:
            continue
        if BENCHMARK_NAME_MAPPING[scenario_name] not in BENCHMARKS:
            continue
        print(f"Computing metrics for: {BENCHMARK_NAME_MAPPING[scenario_name]}, {model}")
        per_instance_stats_list = read_per_instance_stats(
            os.path.join(run_path, PER_INSTANCE_STATS_FILE_NAME)
        )
        instances = read_instances(os.path.join(run_path, INSTANCES_FILE_NAME))
        # Scenario args (when present) identify the subgroup of the benchmark.
        subgroup = ""
        if len(run_spec.scenario_spec.args) > 0:
            subgroup = run_spec.scenario_spec.args

        per_instance_rows.extend(
            compute_labels(
                per_instance_stats_list,
                instances,
                scenario_name,
                model,
                subgroup
            )
        )

    counts = get_classification_counts(per_instance_rows)
    results = get_precision_recall_f1(counts)

    for result in results:
        print(f"{result['scenario']}, {result['model']}: Precision = {result['precision']:.4f}, Recall = {result['recall']:.4f}, F1 Score = {result['f1_score']:.4f}")

    print(f"Total benchmarks processed: {len(results)}")
    print("Writing results to output file:", output_path)
    df = pd.DataFrame(results)
    df.to_csv(output_path, index=False)


if __name__ == "__main__":
    # CLI entry point: -b points at the benchmark output root (which contains
    # a "runs" subdirectory), -o at the CSV to write.
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark_path", "-b", type=str, required=True, help="Path to the directory containing run outputs")
    parser.add_argument("--output_path", "-o", type=str, required=True, help="Output path for the accuracy metrics csv")
    args = parser.parse_args()

    main(
        # os.path.join is portable, unlike f-string "/" concatenation.
        runs_path=os.path.join(args.benchmark_path, "runs"),
        output_path=args.output_path,
    )