# Model

This README details how to add a model to the benchmark.

## Entrypoints

A model requires only one entrypoint: the `train` method, which you can reference in the two models below:
* [esm/src/pg2_model_esm/__main__.py](esm/src/pg2_model_esm/__main__.py)
* [pls/src/pg2_model_pls/__main__.py](pls/src/pg2_model_pls/__main__.py)

Both **supervised** and **zero-shot** models implement this `train` method: it is the glue that ties the `pg2-dataset` and `pg2-benchmark` packages and the models' original source code together. The method is named `train` because SageMaker looks for a `train` method as its entrypoint, so it serves as the common entrypoint for both environments: local and AWS.

This entrypoint expects a reference to a dataset, e.g., loaded by `pg2-dataset`:

```python
from pg2_dataset.dataset import Dataset
dataset = Dataset.from_path(dataset_file)
```

Additionally, this entrypoint expects a reference to a model card, e.g., loaded by `pg2-benchmark`:

```python
from pg2_benchmark.manifest import Manifest
manifest = Manifest.from_path(model_toml_file)
```

Finally, inside this `train` method:

* For a **supervised** model, like [esm](esm/), it calls `load` and `infer` in order:
  * `load` uses `manifest` as input, and returns a model object as output.
  * `infer` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.

* For a **zero-shot** model, like [pls](pls/), it calls `train` and `infer` in order:
  * `train` uses `dataset` and `manifest` as input, and returns a model object as output.
  * `infer` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
The result data frame is saved to disk in the local environment and stored in AWS S3 in the cloud environment. After the container is destroyed, the result data frame persists for the later metric calculation.

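As a concrete sketch of the persistence step, the snippet below writes a predictions data frame to a local `output/` directory with `pandas`; the column names are illustrative only, not part of the benchmark contract:

```python
from pathlib import Path

import pandas as pd

# Hypothetical predictions as returned by `infer`; the columns are
# illustrative, not a required schema.
df = pd.DataFrame({"mutant": ["A1G", "C2T"], "prediction": [0.12, 0.87]})

# Persist to disk so the result survives after the container exits.
output_dir = Path("output")
output_dir.mkdir(exist_ok=True)
df.to_csv(output_dir / "predictions.csv", index=False)
```

In the SageMaker environment, the same call targets the mounted output path described in the [Backends](#backends) section, and the CSV ends up in S3.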
For reference, below is an example Python implementation with `typer`:

``` python
# In `__main__.py`
from pathlib import Path
from typing import Annotated

import typer
from pg2_dataset import Dataset
from pg2_benchmark import Manifest

# `load`, `train` and `infer` live in `model.py` (see the suggested code
# structure); `train` is aliased to avoid shadowing the CLI command below.
from .model import infer, load
from .model import train as train_model


app = typer.Typer(
    help="My ProteinGym model",
    add_completion=True,
)


@app.command()
def train(
    dataset_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the archived dataset",
        ),
    ],
    model_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the model card file",
        ),
    ],
) -> None:
    dataset = Dataset.from_path(dataset_reference)
    manifest = Manifest.from_path(model_reference)

    # For a supervised model
    model = load(manifest)
    df = infer(dataset, manifest, model)
    df.to_csv(...)

    # For a zero-shot model
    model = train_model(dataset, manifest)
    df = infer(dataset, manifest, model)
    df.to_csv(...)


if __name__ == "__main__":
    app()
```

## Suggested code structure

> [!NOTE]
> The Python examples below translate to other languages too.

In addition to the [**required** entrypoints](#entrypoints), we suggest the
following code structure:

``` tree
├── __main__.py
├── model.py
├── preprocess.py
└── utils.py
```

### `__main__.py`

The `__main__.py` contains the `train` entrypoint as shown above.
The code loads the dataset and model card before passing them to the `load`, `train`
or `infer` methods.

### `preprocess.py`

`preprocess.py` contains the data-preprocessing code, with functions like:

``` python
def encode(data: np.ndarray) -> np.ndarray:
    """Encode the data."""
    return encoded_data
```

``` python
def train_test_split(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split the data."""
    return train_data, test_data
```

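As an illustration, a minimal `train_test_split` could be implemented with NumPy as below; the 80/20 ratio and the absence of shuffling are assumptions made for brevity, not benchmark requirements:

```python
import numpy as np


def train_test_split(
    data: np.ndarray, train_fraction: float = 0.8
) -> tuple[np.ndarray, np.ndarray]:
    """Split the data into train and test partitions (no shuffling, for brevity)."""
    cutoff = int(len(data) * train_fraction)
    return data[:cutoff], data[cutoff:]


# With the default 0.8 fraction, 10 rows split into 8 train and 2 test rows.
train_data, test_data = train_test_split(np.arange(10))
```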
### `model.py`

`model.py` contains the model-related code:

``` python
def train(dataset: Dataset, manifest: Manifest) -> Any:
    """Train the model."""
    X, y = load_x_and_y(
        dataset=dataset,
        split="train",
    )

    model = Model(manifest)
    model.fit(X, y)

    return model
```

``` python
def load(manifest: Manifest) -> Any:
    """Load the model."""
    model = Model.from_manifest(manifest)
    return model
```

``` python
def infer(dataset: Dataset, manifest: Manifest, model: Any) -> DataFrame:
    """Infer predictions on the data."""
    X, y = load_x_and_y(
        dataset=dataset,
        split="test",
    )

    predictions = model.predict(manifest, X)

    df = DataFrame(predictions)

    return df
```

### `utils.py`

It contains the supporting methods from the model's original source code that `model.py` builds on.

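For example, `utils.py` might expose a small helper adapted from the original code base; the function below is purely hypothetical:

```python
import numpy as np

# Hypothetical helper: map amino-acid letters to integer indices so that
# `model.py` can feed sequences to the original model code.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def tokenize(sequence: str) -> np.ndarray:
    """Convert an amino-acid sequence into integer tokens."""
    return np.array([AMINO_ACIDS.index(aa) for aa in sequence])
```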
## Backends

This section details common logic per backend.

### SageMaker

When using SageMaker, the references point to S3 paths mounted into the (Docker)
container. Containers are destroyed after each run, but the data can be
safely persisted in the S3 buckets. These mounted paths are defined as follows:

```python
from pathlib import Path


class SageMakerPathLayout:
    """SageMaker's paths layout."""

    PREFIX: Path = Path("/opt/ml")
    """All SageMaker paths start with this prefix."""

    TRAINING_JOB_PATH: Path = PREFIX / "input" / "data" / "training" / "dataset.zip"
    """Path to the training data."""

    MANIFEST_PATH: Path = PREFIX / "input" / "data" / "manifest" / "manifest.toml"
    """Path to the model manifest."""

    OUTPUT_PATH: Path = PREFIX / "output"
    """Path to the output, such as the result data frames."""
```

For example, to persist the scores for a given dataset and model as CSV:

``` python
scores: pd.DataFrame
scores.to_csv(
    f"{SageMakerPathLayout.OUTPUT_PATH}/{dataset.name}_{model.name}.csv",
    index=False,
)
```