Skip to content

Commit a94c750

Browse files
Merge pull request #124 from ProteinGym/chore/move-metric-calc
PR1: Move metric.py to scripts
2 parents 65bedcd + 7dac9b4 commit a94c750

File tree

7 files changed

+145
-38
lines changed

7 files changed

+145
-38
lines changed

benchmark/supervised/local/dvc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ stages:
3434
dataset: ${datasets}
3535
model: ${models}
3636

37-
cmd: uv run proteingym-benchmark metric calc --output-path ${output.prediction}/${item.dataset.name}_${item.model.name}.csv --metric-path ${output.metric}/${item.dataset.name}_${item.model.name}.csv
37+
cmd: python $(dvc root)/scripts/metric.py --output ${output.prediction}/${item.dataset.name}_${item.model.name}.csv --metric ${output.metric}/${item.dataset.name}_${item.model.name}.csv --actual-vector-col "test" --predict-vector-col "pred"
3838
deps:
3939
- ${output.prediction}/${item.dataset.name}_${item.model.name}.csv
4040
outs:

benchmark/zero_shot/local/dvc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ stages:
3434
dataset: ${datasets}
3535
model: ${models}
3636

37-
cmd: uv run proteingym-benchmark metric calc --output-path ${output.prediction}/${item.dataset.name}_${item.model.name}.csv --metric-path ${output.metric}/${item.dataset.name}_${item.model.name}.csv
37+
cmd: python $(dvc root)/scripts/metric.py --output ${output.prediction}/${item.dataset.name}_${item.model.name}.csv --metric ${output.metric}/${item.dataset.name}_${item.model.name}.csv --actual-vector-col "test" --predict-vector-col "pred"
3838
deps:
3939
- ${output.prediction}/${item.dataset.name}_${item.model.name}.csv
4040
outs:

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
polars[pyarrow]>=1.30.0,<2.0.0
2+
pycm==4.4

scripts/README.md

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Scripts
2+
3+
This directory contains utility scripts for the ProteinGym benchmark.
4+
5+
## Dependencies
6+
7+
Make sure to install the required dependencies:
8+
9+
```shell
10+
pip install -r requirements.txt
11+
```
12+
13+
## metric.py
14+
15+
The [metric.py](metric.py) script calculates performance metrics for machine learning models by comparing actual and predicted values.
16+
17+
### Arguments
18+
19+
- `--output`: Path to the input CSV file containing the prediction results to be scored (read by the script, despite the flag name)
20+
- `--metric`: Path where the calculated metrics CSV will be saved
21+
- `--actual-vector-col`: Column name containing actual/ground truth values
22+
- `--predict-vector-col`: Column name containing predicted values
23+
24+
### Example
25+
26+
```shell
27+
python metric.py \
28+
--output predictions.csv \
29+
--metric metrics.csv \
30+
--actual-vector-col "true_values" \
31+
--predict-vector-col "predicted_values"
32+
```

scripts/metric.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""
2+
Metric calculation script for ProteinGym benchmark evaluation.
3+
4+
This script provides functionality to calculate performance metrics for machine learning models
5+
by comparing actual and predicted values. It computes classification metrics via confusion
6+
matrix from CSV output files.
7+
8+
The main function `calc` reads prediction results from a CSV file, generates a confusion matrix
9+
with comprehensive classification statistics, and outputs all metrics to a CSV file for further analysis.
10+
11+
Example output CSV:
12+
| Metric | Value |
13+
|--------------|------------|
14+
| Overall ACC | 0.85 |
15+
| PPV Macro | 'None' |
16+
| Kappa 95% CI | (0.0, 0.0) |
17+
18+
Functions:
19+
calc: Calculate and save performance metrics from prediction output files
20+
"""
21+
22+
import argparse
23+
from pathlib import Path
24+
25+
import polars as pl
26+
from pycm import ConfusionMatrix
27+
28+
29+
def calc(
    output: Path, metric: Path, actual_vector_col: str, predict_vector_col: str
) -> Path:
    """Calculate performance metrics from prediction output and save to CSV.

    Reads prediction results from a CSV file, computes classification metrics
    using a pycm confusion matrix, and writes every overall statistic to a CSV
    file with two columns: ``metric_name`` and ``metric_value``.

    Args:
        output: Path to the CSV file containing prediction results (input).
        metric: Path where the calculated metrics CSV will be saved.
        actual_vector_col: Column name containing actual/ground truth values.
        predict_vector_col: Column name containing predicted values.

    Returns:
        The path the metrics CSV was written to (same as ``metric``).

    Raises:
        ValueError: If either requested column is missing from the input CSV.
    """

    print("Start to calculate metrics.")

    output_dataframe = pl.read_csv(output)

    # Fail fast with a clear, actionable message instead of surfacing a
    # library-specific error from the column lookup below.
    missing = [
        col
        for col in (actual_vector_col, predict_vector_col)
        if col not in output_dataframe.columns
    ]
    if missing:
        raise ValueError(f"Column(s) not found in {output}: {missing}")

    cm = ConfusionMatrix(
        actual_vector=output_dataframe[actual_vector_col].to_list(),
        predict_vector=output_dataframe[predict_vector_col].to_list(),
    )

    # overall_stat values are heterogeneous (floats, tuples, None, strings),
    # so everything is stringified to fit a single CSV column.
    metrics_data = [
        {"metric_name": key, "metric_value": str(value)}
        for key, value in cm.overall_stat.items()
    ]

    metric_dataframe = pl.DataFrame(
        data=metrics_data,
        schema={"metric_name": pl.String, "metric_value": pl.String},
    )

    metric_dataframe.write_csv(metric)

    return metric
66+
67+
68+
def main():
    """Entry point: parse the CLI arguments and delegate to ``calc``.

    Returns:
        The path of the written metrics CSV, as returned by ``calc``.
    """
    parser = argparse.ArgumentParser(
        description="Calculate metric for ProteinGym benchmark evaluation."
    )

    # The two file-path options, declared table-style: (flag, help text).
    for flag, help_text in (
        ("--output", "Path to the CSV file containing prediction results"),
        ("--metric", "Path where the calculated metrics CSV will be saved"),
    ):
        parser.add_argument(flag, type=Path, required=True, help=help_text)

    # The two column-name options.
    for flag, help_text in (
        (
            "--actual-vector-col",
            "Column name containing actual/ground truth values",
        ),
        (
            "--predict-vector-col",
            "Column name containing predicted values",
        ),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)

    namespace = parser.parse_args()

    return calc(
        output=namespace.output,
        metric=namespace.metric,
        actual_vector_col=namespace.actual_vector_col,
        predict_vector_col=namespace.predict_vector_col,
    )
106+
107+
108+
if __name__ == "__main__":
    # Run the CLI, then report where the metrics were written.
    saved_path = main()
    print(f"Metrics have been saved to {saved_path}.")

src/proteingym/benchmark/__main__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import typer
55

66
from .__about__ import __version__
7-
from .cli.metric import metric_app
87
from .cli.sagemaker import sagemaker_app
98

109
app = typer.Typer(
@@ -13,7 +12,6 @@
1312
add_completion=False,
1413
)
1514

16-
app.add_typer(metric_app, name="metric", help="Metric operations")
1715
app.add_typer(sagemaker_app, name="sagemaker", help="SageMaker operations")
1816

1917

src/proteingym/benchmark/cli/metric.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)