Commit 9f1edd6

Merge pull request #88 from ProteinGym2/chore/standardize-model
PR3: Structure models API to facilitate easy-to-read model containerisation.
2 parents 3684ee0 + c898000

File tree: 9 files changed, +600 −380 lines

models/README.md

Lines changed: 206 additions & 0 deletions
# Model

This README details how to add a model to the benchmark.

## Entrypoints

A model requires only one entrypoint: the `train` method, which you can reference in these two models:

* [esm/src/pg2_model_esm/__main__.py](esm/src/pg2_model_esm/__main__.py)
* [pls/src/pg2_model_pls/__main__.py](pls/src/pg2_model_pls/__main__.py)

Both **supervised** and **zero-shot** models expose this `train` method, because it is the glue that ties the `pg2-dataset` and `pg2-benchmark` packages together with each model's original source code. The method is named `train` because SageMaker looks for a `train` method as its entrypoint, so it becomes the common entrypoint for both environments: local and AWS.

This entrypoint expects a reference to a dataset, e.g., loaded by `pg2-dataset`:

``` python
from pg2_dataset.dataset import Dataset

dataset = Dataset.from_path(dataset_file)
```

Additionally, this entrypoint expects a reference to a model card, e.g., loaded by `pg2-benchmark`:

``` python
from pg2_benchmark.manifest import Manifest

manifest = Manifest.from_path(model_toml_file)
```

Finally, inside this `train` method:

* A **supervised** model, like [esm](esm/), calls `load` and `infer` in order:
  * `load` takes `manifest` as input and returns a model object.
  * `infer` takes `dataset`, `manifest` and the model object as input, and returns the predictions as a data frame.

* A **zero-shot** model, like [pls](pls/), calls `train` and `infer` in order:
  * `train` takes `dataset` and `manifest` as input and returns a model object.
  * `infer` takes `dataset`, `manifest` and the model object as input, and returns the predictions as a data frame.

The resulting data frame is saved to disk in the local environment and stored in AWS S3 in the cloud environment. It is persisted after the container is destroyed, so it remains available for the later metric calculation.
For reference, below is an example Python implementation with `typer`:

``` python
# In `__main__.py`
from pathlib import Path
from typing import Annotated

import typer

from pg2_dataset import Dataset
from pg2_benchmark import Manifest


app = typer.Typer(
    help="My ProteinGym model",
    add_completion=True,
)


@app.command()
def train(
    dataset_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the archived dataset",
        ),
    ],
    model_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the model card file",
        ),
    ],
) -> None:
    dataset = Dataset.from_path(dataset_reference)
    manifest = Manifest.from_path(model_reference)

    # For a supervised model (`load` and `infer` are defined in `model.py`)
    model = load(manifest)
    df = infer(dataset, manifest, model)
    df.to_csv(...)

    # For a zero-shot model (`train` and `infer` are defined in `model.py`)
    model = train(dataset, manifest)
    df = infer(dataset, manifest, model)
    df.to_csv(...)


if __name__ == "__main__":
    app()
```

## Suggested code structure

> [!NOTE]
> The Python examples below translate to other languages too.

In addition to the [**required** entrypoints](#entrypoints), we suggest the
following code structure:

``` tree
├── __main__.py
├── model.py
├── preprocess.py
└── utils.py
```

### `__main__.py`

`__main__.py` contains the `train` entrypoint shown above.
It loads the dataset and the model card before passing them to the `load`, `train`
or `infer` methods.
### `preprocess.py`

`preprocess.py` contains the data-preprocessing code, with functions like:

``` python
def encode(data: np.ndarray) -> np.ndarray:
    """Encode the data."""
    return encoded_data
```

``` python
def train_test_split(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split the data."""
    return train_data, test_data
```
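The stubs above are intentionally schematic. As one concrete (hypothetical) variant, a fraction-based split over an in-memory sequence could look like:

```python
def train_test_split(data: list, test_fraction: float = 0.2) -> tuple[list, list]:
    """Split data into train and test partitions.

    The last `test_fraction` of the records becomes the test set;
    the rest is the training set.
    """
    n_test = int(len(data) * test_fraction)
    split_at = len(data) - n_test
    return data[:split_at], data[split_at:]
```

A positional split like this only makes sense when record order carries no signal; shuffle first if it does.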
### `model.py`

`model.py` contains the model-related code:

``` python
def train(dataset: Dataset, manifest: Manifest) -> Any:
    """Train the model."""
    X, y = load_x_and_y(
        dataset=dataset,
        split="train",
    )

    model = Model(manifest)
    model.fit(X, y)

    return model
```

``` python
def load(manifest: Manifest) -> Any:
    """Load the model."""
    model = Model.from_manifest(manifest)
    return model
```

``` python
def infer(dataset: Dataset, manifest: Manifest, model: Any) -> DataFrame:
    """Infer predictions on the data."""
    X, y = load_x_and_y(
        dataset=dataset,
        split="test",
    )

    predictions = model.predict(manifest, X)

    df = DataFrame(predictions)

    return df
```
### `utils.py`

`utils.py` contains the supporting methods from the model's original code base that `model.py` builds on.
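For example, parsing mutation strings is a natural `utils.py` candidate for mutation-scoring models. A minimal sketch; the function name and the mutation format (`A123G`, 1-based positions by default) are our assumptions, not part of the benchmark API:

```python
import re


def parse_mutation(mutation: str, offset_idx: int = 1) -> tuple[str, int, str]:
    """Parse a mutation string such as "A123G" into (wild-type, 0-based position, mutant)."""
    match = re.fullmatch(r"([A-Z])(\d+)([A-Z])", mutation)
    if match is None:
        raise ValueError(f"Invalid mutation: {mutation!r}")
    wt, pos, mut = match.group(1), int(match.group(2)), match.group(3)
    # Subtract the offset so the position indexes directly into the sequence.
    return wt, pos - offset_idx, mut
```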
## Backends

This section details common logic per backend.

### SageMaker

When using SageMaker, the references point to S3 paths mounted into the (Docker)
container. Containers are destroyed after they run, but the data can be
safely persisted in the S3 buckets. These mounted paths are defined as follows:
```python
class SageMakerPathLayout:
    """SageMaker's path layout."""

    PREFIX: Path = Path("/opt/ml")
    """All SageMaker paths start with this prefix."""

    TRAINING_JOB_PATH: Path = PREFIX / "input" / "data" / "training" / "dataset.zip"
    """Path to the training data."""

    MANIFEST_PATH: Path = PREFIX / "input" / "data" / "manifest" / "manifest.toml"
    """Path to the model manifest."""

    OUTPUT_PATH: Path = PREFIX / "output"
    """Path to the output, such as the result data frames."""
```

For example, to persist the scores for a given dataset and model as CSV:

``` python
scores: pd.DataFrame

scores.to_csv(
    f"{SageMakerPathLayout.OUTPUT_PATH}/{dataset.name}_{model.name}.csv",
    index=False,
)
```
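The same persistence logic also has to work locally, where `/opt/ml` is not mounted. A small helper can pick the output directory at runtime; this is a sketch under the assumption that the SageMaker prefix only exists inside the container, and `output_dir` is a hypothetical name:

```python
from pathlib import Path


def output_dir(prefix: Path = Path("/opt/ml"), local_fallback: Path = Path("output")) -> Path:
    """Return the SageMaker output directory when mounted, else a local directory.

    Hypothetical helper: `prefix` mirrors `SageMakerPathLayout.PREFIX`.
    """
    if prefix.exists():
        # Running inside the SageMaker container: outputs are synced to S3.
        return prefix / "output"
    # Running locally: write next to the working directory.
    local_fallback.mkdir(parents=True, exist_ok=True)
    return local_fallback
```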

models/esm/src/pg2_model_esm/__main__.py

Lines changed: 7 additions & 101 deletions
```diff
@@ -1,12 +1,9 @@
 from typing import Annotated
 from pathlib import Path
-import torch
 import typer
 from rich.console import Console
 from pg2_dataset.dataset import Dataset
-from tqdm import tqdm
-from esm import pretrained
-from pg2_model_esm.utils import compute_pppl, label_row
+from pg2_model_esm.model import load, infer
 from pg2_benchmark.manifest import Manifest
 
 
@@ -15,19 +12,15 @@
     add_completion=True,
 )
 
-err_console = Console(stderr=True)
 console = Console()
 
 
 class SageMakerTrainingJobPath:
     PREFIX = Path("/opt/ml")
     TRAINING_JOB_PATH = PREFIX / "input" / "data" / "training" / "dataset.zip"
     MANIFEST_PATH = PREFIX / "input" / "data" / "manifest" / "manifest.toml"
-    PARAMS_PATH = PREFIX / "input" / "config" / "hyperparameters.json"
     OUTPUT_PATH = PREFIX / "model"
 
-    MODEL_PATH = Path("/model.pkl")
-
 
 @app.command()
 def train(
@@ -47,103 +40,17 @@ def train(
     console.print(f"Loading {dataset_file} and {model_toml_file}...")
 
     dataset = Dataset.from_path(dataset_file)
-
-    assays = dataset.assays.meta.assays
-    targets = list(dataset.assays.meta.assays.keys())
-
-    sequence = assays[targets[0]].constants["sequence"]
-    mutation_col = assays[targets[0]].constants["mutation_col"]
-
-    df = dataset.assays.data_frame
-
-    console.print(f"Loaded {len(df)} records.")
-
     manifest = Manifest.from_path(model_toml_file)
 
-    model, alphabet = pretrained.load_model_and_alphabet(
-        manifest.hyper_params["location"]
-    )
-    model.eval()
+    model, alphabet = load(manifest)
 
-    console.print(
-        f"Loaded the model from {manifest.hyper_params['location']} with scoring strategy {manifest.hyper_params['scoring_strategy']}."
+    df = infer(
+        dataset=dataset,
+        manifest=manifest,
+        model=model,
+        alphabet=alphabet,
     )
 
-    if torch.cuda.is_available() and not manifest.hyper_params["nogpu"]:
-        model = model.cuda()
-        print("Transferred model to GPU")
-
-    batch_converter = alphabet.get_batch_converter()
-
-    data = [
-        ("protein1", sequence),
-    ]
-
-    batch_labels, batch_strs, batch_tokens = batch_converter(data)
-
-    match manifest.hyper_params["scoring_strategy"]:
-        case "wt-marginals":
-            with torch.no_grad():
-                token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
-
-            df["pred"] = df.apply(
-                lambda row: label_row(
-                    row[mutation_col],
-                    sequence,
-                    token_probs,
-                    alphabet,
-                    manifest.hyper_params["offset_idx"],
-                ),
-                axis=1,
-            )
-
-        case "masked-marginals":
-            all_token_probs = []
-
-            for i in tqdm(range(batch_tokens.size(1))):
-                batch_tokens_masked = batch_tokens.clone()
-                batch_tokens_masked[0, i] = alphabet.mask_idx
-
-                with torch.no_grad():
-                    token_probs = torch.log_softmax(
-                        model(batch_tokens_masked)["logits"], dim=-1
-                    )
-
-                all_token_probs.append(token_probs[:, i])  # vocab size
-
-            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
-
-            df["pred"] = df.apply(
-                lambda row: label_row(
-                    row[mutation_col],
-                    sequence,
-                    token_probs,
-                    alphabet,
-                    manifest.hyper_params["offset_idx"],
-                ),
-                axis=1,
-            )
-
-        case "pseudo-ppl":
-            tqdm.pandas()
-
-            df["pred"] = df.progress_apply(
-                lambda row: compute_pppl(
-                    row[mutation_col],
-                    sequence,
-                    model,
-                    alphabet,
-                    manifest.hyper_params["offset_idx"],
-                ),
-                axis=1,
-            )
-
-        case _:
-            err_console.print(
-                f"Error: Invalid scoring strategy: {manifest.hyper_params['scoring_strategy']}"
-            )
-
-    df.rename(columns={targets[0]: "test"}, inplace=True)
     df.to_csv(
         f"{SageMakerTrainingJobPath.OUTPUT_PATH}/{dataset.name}_{manifest.name}.csv",
         index=False,
@@ -152,7 +59,6 @@ def train(
     console.print(
         f"Saved the metrics in CSV in {SageMakerTrainingJobPath.OUTPUT_PATH}/{dataset.name}_{manifest.name}.csv"
     )
-    console.print("Done.")
 
 
 @app.command()
```
