Skip to content

Commit c898000

Browse files
committed
Update method name from predict to infer
1 parent b998100 commit c898000

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

models/README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ manifest = Manifest.from_path(model_toml_file)
2727

2828
Finally, inside this `train` method:
2929

30-
* For a **supervised** model, like [esm](esm/), it calls `load_model` and `predict_model` in order:
31-
* `load_model` uses `manifest` as input, and returns a model object as output.
32-
* `predict_model` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
30+
* For a **supervised** model, like [esm](esm/), it calls `load` and `infer` in order:
31+
* `load` uses `manifest` as input, and returns a model object as output.
32+
* `infer` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
3333

34-
* For a **zero-shot** model, like [pls](pls/), it calls `train_model` and `predict_model` in order:
35-
* `train_model` uses `dataset` and `manifest` as input, and returns a model object as output.
36-
* `predict_model` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
34+
* For a **zero-shot** model, like [pls](pls/), it calls `train` and `infer` in order:
35+
* `train` uses `dataset` and `manifest` as input, and returns a model object as output.
36+
* `infer` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
3737

3838
The result data frame is saved on the disk in the local environment and stored in AWS S3 in the cloud environment. After the container is destroyed, the result data frame is persisted for the later metric calculation.
3939

@@ -73,12 +73,12 @@ def train(
7373

7474
# For a supervised model
7575
model = load(manifest)
76-
df = predict(dataset, manifest, model)
76+
df = infer(dataset, manifest, model)
7777
df.to_csv(...)
7878

7979
# For a zero-shot model
8080
model = train(dataset, manifest)
81-
df = predict(dataset, manifest, model)
81+
df = infer(dataset, manifest, model)
8282
df.to_csv(...)
8383

8484

@@ -105,8 +105,8 @@ following code structure:
105105
### `__main__.py`
106106

107107
The `__main__.py` contains the `train` entrypoint as shown above.
108-
The code loads the dataset and model (card) before passing it to the `load_model`, `train_model`
109-
or `predict_model` methods.
108+
The code loads the dataset and model (card) before passing it to the `load`, `train`
109+
or `infer` methods.
110110

111111
### `preprocess.py`
112112

@@ -150,7 +150,7 @@ def load(manifest: Manifest) -> Any:
150150
```
151151

152152
``` python
153-
def predict(dataset: Dataset, manifest: Manifest, model: Any) -> DataFrame:
153+
def infer(dataset: Dataset, manifest: Manifest, model: Any) -> DataFrame:
154154
"""Infer predictions on the data."""
155155
X, y = load_x_and_y(
156156
dataset=dataset,

models/esm/src/pg2_model_esm/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typer
44
from rich.console import Console
55
from pg2_dataset.dataset import Dataset
6-
from pg2_model_esm.model import load, predict
6+
from pg2_model_esm.model import load, infer
77
from pg2_benchmark.manifest import Manifest
88

99

@@ -44,7 +44,7 @@ def train(
4444

4545
model, alphabet = load(manifest)
4646

47-
df = predict(
47+
df = infer(
4848
dataset=dataset,
4949
manifest=manifest,
5050
model=model,

models/esm/src/pg2_model_esm/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def load(manifest: Manifest) -> tuple[torch.nn.Module, Alphabet]:
3737
return model, alphabet
3838

3939

40-
def predict(
40+
def infer(
4141
dataset: Dataset,
4242
manifest: Manifest,
4343
model: torch.nn.Module,

models/pls/src/pg2_model_pls/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33
from rich.console import Console
44
from pg2_dataset.dataset import Dataset
5-
from pg2_model_pls.model import train as train_model, predict
5+
from pg2_model_pls.model import train as train_model, infer
66
from pg2_benchmark.manifest import Manifest
77

88
import typer
@@ -47,7 +47,7 @@ def train(
4747
manifest=manifest,
4848
)
4949

50-
df = predict(
50+
df = infer(
5151
dataset=dataset,
5252
manifest=manifest,
5353
model=model,

models/pls/src/pg2_model_pls/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def train(
4646
return model
4747

4848

49-
def predict(
49+
def infer(
5050
dataset: Dataset,
5151
manifest: Manifest,
5252
model: PLSRegression,

0 commit comments

Comments
 (0)