Commit b998100

Update source code according to README.md APIs

1 parent 92e18a4

10 files changed: 228 additions, 271 deletions

models/README.md

Lines changed: 72 additions & 84 deletions
````diff
@@ -4,31 +4,40 @@ This README details how to add a model to the benchmark.
 
 ## Entrypoints
 
-A model requires the following entrypoints: `train` and `predict`:
+A model requires only one entrypoint: the `train` method, which you can reference in the two models below:
 
-The `train` entrypoint is only required for **supervised** models.
-The `predict` entrypoint is required for all models.
+* [esm/src/pg2_model_esm/__main__.py](esm/src/pg2_model_esm/__main__.py)
+* [pls/src/pg2_model_pls/__main__.py](pls/src/pg2_model_pls/__main__.py)
 
-Both entrypoints expect a reference to a dataset: `dataset_reference`.
-Additionally, the `train` entrypoint expects a reference to the model card
-and the `predict` entrypoint expects a reference to the peristed model:
-`model_card_reference` and `model_reference`, respectively.
+Both **supervised** and **zero-shot** models call this `train` method: it is the glue that ties the `pg2-dataset` and `pg2-benchmark` packages and the models' original source code together. The method is named `train` because SageMaker looks for a `train` method as its entrypoint, so it serves as the common entrypoint in both environments: local and AWS.
 
-Finally, the `train` entrypoint outputs the model reference, which is the input
-for the `predict` entrypoint next to the dataset. The `predict` entrypoints
-outputs the inferred predictions:
+This entrypoint expects a reference to a dataset, e.g., loaded by `pg2-dataset`:
 
-From the commandline these entrypoints interact as follows:
+```python
+from pg2_dataset.dataset import Dataset
+dataset = Dataset.from_path(dataset_file)
+```
+
+Additionally, this entrypoint expects a reference to a model card, e.g., loaded by `pg2-benchmark`:
 
-``` bash
-$ train ./path/to/dataset_train.pgdata ./path/to/model_card.md
-./path/to/model.pickle
-$ predict ./path/to/dataset_validate.pgdata ./path/to/model.pickle
-[0.8, 0.5, ..., .04]
+```python
+from pg2_benchmark.manifest import Manifest
+manifest = Manifest.from_path(model_toml_file)
 ```
 
-For reference, below an example Python implementation with `typer`:
+Finally, inside this `train` method:
 
+* For a **supervised** model, like [esm](esm/), it calls `load_model` and `predict_model` in order:
+  * `load_model` takes `manifest` as input, and returns a model object.
+  * `predict_model` takes `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame.
+
+* For a **zero-shot** model, like [pls](pls/), it calls `train_model` and `predict_model` in order:
+  * `train_model` takes `dataset` and `manifest` as input, and returns a model object.
+  * `predict_model` takes `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame.
+
+The result data frame is saved on disk in the local environment and stored in AWS S3 in the cloud environment, so it persists after the container is destroyed for the later metric calculation.
+
+For reference, below is an example Python implementation with `typer`:
 
 ``` python
 # In `__main__.py`
````
````diff
@@ -58,52 +67,19 @@ def train(
         ),
     ],
 ) -> Path:
-    """Train the model on the dataset.
-
-    Args:
-        dataset_reference (Path) : Path to the archived dataset.
-        model_reference (Path) : Path to the model card file.
-
-    Returns:
-        Path : The trained and persisted model.
-    """
+
     dataset = Dataset.from_path(dataset_path)
     manifest = Manifest.from_path(model_card_path)
 
-    # Train the model below
-    model_reference = ...
-    return model_reference
-
-
-def predict(
-    dataset_reference: Annotated[
-        Path,
-        typer.Option(
-            help="Path to the archived dataset",
-        ),
-    ],
-    model_reference: Annotated[
-        Path,
-        typer.Option(
-            help="Path to the model file",
-        ),
-    ],
-) -> Iterable[float]:
-    """Predict (aka infer) given the dataset and the model.
-
-    Args:
-        dataset_reference (Path) : Path to the archived dataset.
-        model_reference (Path) : Path to the persisted and trained model file.
-
-    Returns:
-        Iterable[float] : The predictions.
-    """
-    dataset = Dataset.from_path(dataset_path)
-    model = pickle.load(model_reference)
+    # For a supervised model
+    model = load(manifest)
+    df = predict(dataset, manifest, model)
+    df.to_csv(...)
 
-    # Predict the model below
-    predictions = ...
-    return predictions
+    # For a zero-shot model
+    model = train(dataset, manifest)
+    df = predict(dataset, manifest, model)
+    df.to_csv(...)
 
 
 if __name__ == "__main__":
````
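Pieced together from the hunks above, the control flow of the single `train` entrypoint can be sketched with stand-in stubs. `Dataset`, `Manifest`, and the model functions here are placeholders for the real pg2-dataset, pg2-benchmark, and model-package objects, and the `zero_shot` flag is an assumption about how a model declares its kind, not the real API:

```python
# Sketch of the single `train` entrypoint's control flow.
# All names here are stand-ins for the real pg2-dataset / pg2-benchmark /
# model-package objects, which are not fully shown in the diff.
from dataclasses import dataclass


@dataclass
class Manifest:
    name: str
    zero_shot: bool  # hypothetical flag; how the real Manifest marks this is not shown


def load_model(manifest: Manifest) -> dict:
    """Supervised case: restore a pretrained model (stub)."""
    return {"model": manifest.name}


def train_model(dataset: str, manifest: Manifest) -> dict:
    """Zero-shot case: fit a model on the dataset (stub)."""
    return {"model": manifest.name, "fitted_on": dataset}


def predict_model(dataset: str, manifest: Manifest, model: dict) -> list[float]:
    """Stub for predict_model(dataset, manifest, model) -> data frame."""
    return [0.8, 0.5, 0.4]


def train(dataset: str, manifest: Manifest) -> list[float]:
    """The glue entrypoint: obtain a model one way or the other, then predict."""
    if manifest.zero_shot:
        model = train_model(dataset, manifest)
    else:
        model = load_model(manifest)
    return predict_model(dataset, manifest, model)
```

In the real entrypoint the resulting data frame is written out with `df.to_csv(...)` rather than returned.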
````diff
@@ -121,18 +97,18 @@ following code structure:
 
 ``` tree
 ├── __main__.py
-├── predict.py  # For supervised models only
+├── model.py
 ├── preprocess.py
-└── train.py
+└── utils.py
 ```
 
 ### `__main__.py`
 
-The `__main__.py` contains the `train` and `predict` entrypoints as shown above.
-The code loads the dataset and model (card) before passing it to the `train_model`
-or `predict_model` methods after preprocessing.
+The `__main__.py` contains the `train` entrypoint as shown above.
+The code loads the dataset and model (card) before passing them to the `load_model`, `train_model`
+or `predict_model` methods.
 
-### `preprocess.py
+### `preprocess.py`
 
 `preprocess.py` contains the data preprocessing code, functions like:
 
````
````diff
@@ -148,35 +124,50 @@ def train_test_split(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
     return train_data, test_data
 ```
 
-### `train.py`
+### `model.py`
 
-`train.py` contains the training code, functions like:
+`model.py` contains the model-related code:
 
 ``` python
-def train(model, Any, X: np.ndarray, y: np.array) -> Path
+def train(dataset: Dataset, manifest: Manifest) -> Any:
     """Train the model."""
+    X, y = load_x_and_y(
+        dataset=dataset,
+        split="train",
+    )
+
+    model = Model(manifest)
     model.fit(X, y)
-    model_path = model.persist()
-    return model_path
+
+    return model
 ```
 
 ``` python
-def load(model_card_reference: Path) -> Any:
+def load(manifest: Manifest) -> Any:
     """Load the model."""
-    model_config = ModelCard.from_path(model_card_reference)
-    model = Model.from_config(model_config)
+    model = Model.from_manifest(manifest)
     return model
 ```
 
-### `predict.py`
-
 ``` python
-def predict(model: Any, X: np.ndarray) -> np.array:
+def predict(dataset: Dataset, manifest: Manifest, model: Any) -> DataFrame:
     """Infer predictions on the data."""
-    predictions = model.predict(X)
-    return predictions
+    X, y = load_x_and_y(
+        dataset=dataset,
+        split="test",
+    )
+
+    predictions = model.predict(manifest, X)
+
+    df = DataFrame(predictions)
+
+    return df
 ```
 
````
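The `load_x_and_y` helper used in `train` and `predict` above is never defined in the diff. A minimal sketch, assuming the dataset maps split names to feature/label pairs (an assumption, not the real pg2-dataset accessor):

```python
import numpy as np


def load_x_and_y(dataset: dict, split: str) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: return features X and labels y for one split.

    Here `dataset` is assumed to map split names ("train"/"test") to
    (features, labels) pairs; the real pg2-dataset API is not shown.
    """
    features, labels = dataset[split]
    return np.asarray(features), np.asarray(labels)
```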
````diff
+### `utils.py`
+
+It contains helper methods from the original models' code that are used by `model.py`.
+
 ## Backends
 
 This section details common logic per backend.
````
````diff
@@ -197,14 +188,11 @@ class SageMakerPathLayout:
     TRAINING_JOB_PATH: Path = PREFIX / "input" / "data" / "training" / "dataset.zip"
     """Path to training data."""
 
-    MODEL_CARD_PATH: PAth = PREFIX / "input" / "config" / "model_card.md"
-    """Path to the model card."""
-
-    MODEL_PATH: Path = Path("/model.pkl")
-    """Model path."""
+    MANIFEST_PATH: Path = PREFIX / "input" / "data" / "manifest" / "manifest.toml"
+    """Path to the model manifest."""
 
     OUTPUT_PATH = PREFIX / "output"
-    """Output path"""
+    """Path to the output, such as the result data frames."""
 ```
 
 For example, to persist the score for a given dataset and model as csv:
````
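The example that follows this sentence is cut off in this excerpt. A sketch of what such persistence might look like, assuming pandas and the `<dataset>_<model>.csv` naming used by the esm entrypoint; `persist_scores` is a hypothetical helper, not part of the real API:

```python
# Sketch: persist the score data frame as CSV under the backend output path.
from pathlib import Path

import pandas as pd

OUTPUT_PATH = Path("/opt/ml/output")  # mirrors SageMakerPathLayout.OUTPUT_PATH


def persist_scores(df: pd.DataFrame, dataset_name: str, model_name: str,
                   output_path: Path = OUTPUT_PATH) -> Path:
    """Write the predictions data frame to `<output>/<dataset>_<model>.csv`."""
    target = output_path / f"{dataset_name}_{model_name}.csv"
    target.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(target, index=False)
    return target
```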

models/esm/src/pg2_model_esm/__main__.py

Lines changed: 3 additions & 9 deletions
````diff
@@ -3,7 +3,7 @@
 import typer
 from rich.console import Console
 from pg2_dataset.dataset import Dataset
-from pg2_model_esm.predict import load_model, predict_model
+from pg2_model_esm.model import load, predict
 from pg2_benchmark.manifest import Manifest
 
 
@@ -12,19 +12,15 @@
     add_completion=True,
 )
 
-err_console = Console(stderr=True)
 console = Console()
 
 
 class SageMakerTrainingJobPath:
     PREFIX = Path("/opt/ml")
     TRAINING_JOB_PATH = PREFIX / "input" / "data" / "training" / "dataset.zip"
     MANIFEST_PATH = PREFIX / "input" / "data" / "manifest" / "manifest.toml"
-    PARAMS_PATH = PREFIX / "input" / "config" / "hyperparameters.json"
     OUTPUT_PATH = PREFIX / "model"
 
-    MODEL_PATH = Path("/model.pkl")
-
 
 @app.command()
 def train(
@@ -44,12 +40,11 @@ def train(
     console.print(f"Loading {dataset_file} and {model_toml_file}...")
 
     dataset = Dataset.from_path(dataset_file)
-
     manifest = Manifest.from_path(model_toml_file)
 
-    model, alphabet = load_model(manifest)
+    model, alphabet = load(manifest)
 
-    df = predict_model(
+    df = predict(
         dataset=dataset,
         manifest=manifest,
         model=model,
@@ -64,7 +59,6 @@ def train(
     console.print(
         f"Saved the metrics in CSV in {SageMakerTrainingJobPath.OUTPUT_PATH}/{dataset.name}_{manifest.name}.csv"
     )
-    console.print("Done.")
 
 
 @app.command()
````
Lines changed: 35 additions & 4 deletions
````diff
@@ -3,16 +3,28 @@
 from tqdm import tqdm
 import pandas as pd
 from esm import pretrained
-from pg2_model_esm.utils import compute_pppl, label_row
-from pg2_benchmark.manifest import Manifest
 from pg2_dataset.dataset import Dataset
+from pg2_benchmark.manifest import Manifest
 from pg2_model_esm.preprocess import encode
+from pg2_model_esm.utils import compute_pppl, label_row
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-def load_model(manifest: Manifest) -> tuple[torch.nn.Module, Alphabet]:
+def load(manifest: Manifest) -> tuple[torch.nn.Module, Alphabet]:
+    """Load and configure an ESM model and its alphabet.
+
+    Loads a pretrained ESM model from the location specified in the manifest,
+    sets it to evaluation mode, and optionally transfers it to GPU if available
+    and not disabled.
+
+    Args:
+        manifest: Configuration object containing model location and GPU settings
+
+    Returns:
+        tuple: The loaded ESM model and its corresponding alphabet
+    """
     model, alphabet = pretrained.load_model_and_alphabet(
         manifest.hyper_params["location"]
     )
@@ -25,12 +37,31 @@ def load_model(manifest: Manifest) -> tuple[torch.nn.Module, Alphabet]:
     return model, alphabet
 
 
-def predict_model(
+def predict(
     dataset: Dataset,
     manifest: Manifest,
     model: torch.nn.Module,
     alphabet: Alphabet,
 ) -> pd.DataFrame:
+    """Generate predictions for protein mutations using an ESM model.
+
+    Computes fitness scores for protein mutations using one of three scoring
+    strategies: wild-type marginals, masked marginals, or pseudo-perplexity.
+    The scoring strategy is determined by the manifest configuration.
+
+    Args:
+        dataset: Dataset containing assay data with mutations to score
+        manifest: Configuration object specifying scoring strategy and parameters
+        model: The loaded ESM model for computing predictions
+        alphabet: ESM alphabet for token encoding/decoding
+
+    Returns:
+        pd.DataFrame: DataFrame with predictions added in 'pred' column and
+            target column renamed to 'test'
+
+    Raises:
+        ValueError: If an unrecognized scoring strategy is specified
+    """
     assays = dataset.assays.meta.assays
     targets = list(dataset.assays.meta.assays.keys())
````
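The new docstring names three scoring strategies selected via the manifest. The dispatch inside `predict` might be structured like this sketch; the strategy keys are assumptions and the scorers are stubs standing in for the real `label_row` / `compute_pppl` helpers:

```python
def score_mutations(strategy: str, rows: list[str]) -> list[float]:
    """Dispatch on an assumed strategy key; the scorers are stand-in stubs."""
    scorers = {
        "wt-marginals": lambda row: 0.1,      # stand-in for wild-type marginals
        "masked-marginals": lambda row: 0.2,  # stand-in for masked marginals
        "pseudo-ppl": lambda row: 0.3,        # stand-in for pseudo-perplexity
    }
    try:
        scorer = scorers[strategy]
    except KeyError:
        raise ValueError(f"Unknown scoring strategy: {strategy!r}")
    return [scorer(row) for row in rows]
```

An unrecognized key raises `ValueError`, matching the `Raises` section of the docstring.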

models/esm/src/pg2_model_esm/preprocess.py

Lines changed: 9 additions & 0 deletions
````diff
@@ -3,6 +3,15 @@
 
 
 def encode(sequence: str, alphabet: Alphabet) -> torch.Tensor:
+    """Encode a protein sequence into tokens using the ESM alphabet.
+
+    Args:
+        sequence: Protein sequence to encode
+        alphabet: ESM alphabet for tokenization
+
+    Returns:
+        Batch tokens tensor for the sequence
+    """
     data = [
         ("protein1", sequence),
     ]
````
