Skip to content

Commit 64c972d

Browse files
committed
Add pg2-benchmark validate cmd
1 parent 74985a6 commit 64c972d

4 files changed

Lines changed: 162 additions & 8 deletions

File tree

README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@ A model repo contains its README.md as a model card, which comes in two parts:
1818

1919
For more information, you can reference Hugging Face's [model cards](https://huggingface.co/docs/hub/en/model-cards).
2020

21+
In order to validate whether you containerise your model correctly, you can run:
22+
23+
```shell
24+
uv run pg2-benchmark validate <your_model_name>
25+
```
26+
27+
For example, after running `uv run pg2-benchmark validate esm`, you will get the following messages to ensure that the model [esm](models/esm/) is containerised correctly with the right model card and entrypoint:
28+
29+
```shell
30+
Uninstalled 34 packages in 504ms
31+
Installed 34 packages in 83ms
32+
✅ Loaded esm with hyper parameters {'location': 'esm2_t30_150M_UR50D', 'scoring_strategy': 'wt-marginals', 'nogpu': False, 'offset_idx': 24}.
33+
✅ Model esm has a valid 'train' entrypoint with required params: ['dataset_file', 'model_card_file']
34+
```
35+
2136
## Datasets
2237

2338
The datasets are included in the [dataset](datasets/) folder, where each dataset goes into a subfolder.
@@ -53,14 +68,14 @@ for dataset in datasets:
5368

5469
You can benchmark a group of supervised models:
5570
```shell
56-
dvc repro benchmark/supervised/local/dvc.yaml
71+
uv run dvc repro benchmark/supervised/local/dvc.yaml
5772
```
5873

5974
#### Zero-shot
6075

6176
You can benchmark a group of zero-shot models:
6277
```shell
63-
dvc repro benchmark/zero_shot/local/dvc.yaml
78+
uv run dvc repro benchmark/zero_shot/local/dvc.yaml
6479
```
6580

6681
### AWS environment
@@ -92,14 +107,14 @@ The difference of the AWS environment is that:
92107

93108
You can benchmark a group of supervised models:
94109
```shell
95-
AWS_ACCOUNT_ID=xxx AWS_PROFILE=yyy dvc repro benchmark/supervised/aws/dvc.yaml
110+
AWS_ACCOUNT_ID=xxx AWS_PROFILE=yyy uv run dvc repro benchmark/supervised/aws/dvc.yaml
96111
```
97112

98113
#### Zero-shot
99114

100115
You can benchmark a group of zero-shot models:
101116
```shell
102-
AWS_ACCOUNT_ID=xxx AWS_PROFILE=yyy dvc repro benchmark/zero_shot/aws/dvc.yaml
117+
AWS_ACCOUNT_ID=xxx AWS_PROFILE=yyy uv run dvc repro benchmark/zero_shot/aws/dvc.yaml
103118
```
104119

105120
## Generate dummy data

models/esm/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ build-backend = "hatchling.build"
2424

2525
[tool.uv.sources]
2626
pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" }
27-
pg2-benchmark = { path = "./pg2-benchmark", editable = true }
27+
pg2-benchmark = { git = "https://github.com/ProteinGym2/pg2-benchmark.git", rev = "main" }
2828

2929
[tool.hatch.build.targets.wheel]
3030
packages = ["src/pg2_model_esm"]

models/pls/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ build-backend = "hatchling.build"
2222

2323
[tool.uv.sources]
2424
pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" }
25-
pg2-benchmark = { path = "./pg2-benchmark", editable = true }
25+
pg2-benchmark = { git = "https://github.com/ProteinGym2/pg2-benchmark.git", rev = "main" }
2626

2727
[tool.hatch.build.targets.wheel]
2828
packages = ["src/pg2_model_pls"]

src/pg2_benchmark/__main__.py

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1+
import json
2+
import subprocess
3+
from pathlib import Path
4+
from typing import Annotated
5+
16
import typer
7+
28
from pg2_benchmark.cli.dataset import dataset_app
39
from pg2_benchmark.cli.metric import metric_app
410
from pg2_benchmark.cli.sagemaker import sagemaker_app
11+
from pg2_benchmark.model_card import ModelCard
12+
13+
14+
class ModelPath:
15+
ROOT_PATH = Path("models")
16+
SRC_PATH = Path("src")
17+
PACKAGE_PREFIX = "pg2_model"
18+
MODEL_CARD_PATH = Path("README.md")
19+
MAIN_PY_PATH = Path("__main__.py")
20+
COMMAND_NAME = "train"
21+
COMMAND_PARAMS = ["dataset_file", "model_card_file"]
22+
523

624
app = typer.Typer(
725
name="benchmark",
@@ -15,8 +33,129 @@
1533

1634

1735
@app.command()
18-
def ping():
19-
typer.echo("pong")
36+
def validate(
37+
model_name: Annotated[
38+
str, typer.Argument(help="The model name listed in the `models` folder")
39+
],
40+
):
41+
model_card_path = ModelPath.ROOT_PATH / model_name / ModelPath.MODEL_CARD_PATH
42+
43+
if not model_card_path.exists():
44+
typer.echo(
45+
f"❌ Model {model_name} does not have a model card at {model_card_path}"
46+
)
47+
raise typer.Exit(1)
48+
49+
try:
50+
model_card = ModelCard.from_path(model_card_path)
51+
typer.echo(
52+
f"✅ Loaded {model_card.name} with hyper parameters {model_card.hyper_params}."
53+
)
54+
55+
except Exception as e:
56+
typer.echo(f"❌ Error loading model card from {model_card_path}: {e}")
57+
raise typer.Exit(1)
58+
59+
main_py_path = (
60+
ModelPath.ROOT_PATH
61+
/ model_name
62+
/ ModelPath.SRC_PATH
63+
/ f"{ModelPath.PACKAGE_PREFIX}_{model_name}"
64+
/ ModelPath.MAIN_PY_PATH
65+
)
66+
67+
if not main_py_path.exists():
68+
typer.echo(
69+
f"❌ Model {model_name} does not have a {ModelPath.MAIN_PY_PATH} file at {main_py_path}"
70+
)
71+
raise typer.Exit(1)
72+
73+
try:
74+
result = subprocess.run(
75+
[
76+
"uv",
77+
"run",
78+
"--active",
79+
"python",
80+
"-c",
81+
f"""
82+
import importlib.util
83+
import inspect
84+
import json
85+
import sys
86+
87+
try:
88+
spec = importlib.util.spec_from_file_location('{ModelPath.PACKAGE_PREFIX}_{model_name}.__main__', '{ModelPath.PACKAGE_PREFIX}_{model_name}/__main__.py')
89+
module = importlib.util.module_from_spec(spec)
90+
spec.loader.exec_module(module)
91+
92+
app = getattr(module, "app")
93+
94+
entrypoint_command_found = False
95+
entrypoint_params_found = False
96+
97+
for command in app.registered_commands:
98+
if '{ModelPath.COMMAND_NAME}' == command.callback.__name__:
99+
entrypoint_command_found = True
100+
101+
sig = inspect.signature(command.callback)
102+
103+
if {ModelPath.COMMAND_PARAMS} == list(sig.parameters.keys()):
104+
entrypoint_params_found = True
105+
106+
break
107+
108+
validation_result = {{
109+
'success': True,
110+
'entrypoint_command_found': entrypoint_command_found,
111+
'entrypoint_params_found': entrypoint_params_found,
112+
'module_loaded': True
113+
}}
114+
115+
print(json.dumps(validation_result))
116+
117+
except Exception as e:
118+
error_result = {{
119+
'success': False,
120+
'entrypoint_command_found': False,
121+
'entrypoint_params_found': False,
122+
'module_loaded': False,
123+
'error': str(e)
124+
}}
125+
print(json.dumps(error_result))
126+
sys.exit(1)
127+
""",
128+
],
129+
cwd=ModelPath.ROOT_PATH / model_name / ModelPath.SRC_PATH,
130+
capture_output=True,
131+
text=True,
132+
)
133+
134+
if result.returncode == 0:
135+
validation_data = json.loads(result.stdout.strip())
136+
137+
if not validation_data["entrypoint_command_found"]:
138+
typer.echo(
139+
f"❌ Model {model_name} does not have a '{ModelPath.COMMAND_NAME}' command"
140+
)
141+
raise typer.Exit(1)
142+
143+
if not validation_data["entrypoint_params_found"]:
144+
typer.echo(
145+
f"❌ Model {model_name}'s '{ModelPath.COMMAND_NAME}' command does not have the required params: {ModelPath.COMMAND_PARAMS}"
146+
)
147+
raise typer.Exit(1)
148+
149+
typer.echo(
150+
f"✅ Model {model_name} has a valid '{ModelPath.COMMAND_NAME}' entrypoint with required params: {ModelPath.COMMAND_PARAMS}"
151+
)
152+
else:
153+
typer.echo(f"❌ Error loading module {main_py_path}: {result.stderr}")
154+
raise typer.Exit(1)
155+
156+
except Exception as e:
157+
typer.echo(f"❌ Error running validation subprocess: {e}")
158+
raise typer.Exit(1)
20159

21160

22161
if __name__ == "__main__":

0 commit comments

Comments
 (0)