Both **supervised** models and **zero-shot** models call this `train` method: it is the glue that binds the packages `pg2-dataset`, `pg2-benchmark` and the models' original source code together. The method is named `train` because SageMaker looks for a `train` method as its entrypoint, so it becomes the common entrypoint for both environments, local and AWS.

This entrypoint expects a reference to a dataset, e.g., loaded by `pg2-dataset`:

```python
from pg2_dataset.dataset import Dataset

dataset = Dataset.from_path(dataset_file)
```

Additionally, this entrypoint expects a reference to a model card, e.g., loaded by `pg2-benchmark`:
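For illustration, here is a minimal self-contained sketch of such a loader. The real `Manifest` class lives in `pg2-benchmark`; the stand-in below, its fields, and the JSON card format are assumptions, shown only to mirror the dataset example above:

```python
import json
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Manifest:
    """Hypothetical stand-in for the model card manifest in pg2-benchmark."""

    name: str
    task: str

    @classmethod
    def from_path(cls, path: Path) -> "Manifest":
        # Parse the model card file into a structured object.
        data = json.loads(Path(path).read_text())
        return cls(name=data["name"], task=data["task"])


# Usage: write a toy model card and load it back.
card_file = Path("model_card.json")
card_file.write_text(json.dumps({"name": "esm", "task": "regression"}))
manifest = Manifest.from_path(card_file)
print(manifest.name)  # esm
```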

Finally, inside this `train` method:

* For a **supervised** model, like [esm](esm/), it calls `load_model` and `predict_model` in order:
  * `load_model` uses `manifest` as input, and returns a model object as output.
  * `predict_model` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.

* For a **zero-shot** model, like [pls](pls/), it calls `train_model` and `predict_model` in order:
  * `train_model` uses `dataset` and `manifest` as input, and returns a model object as output.
  * `predict_model` uses `dataset`, `manifest` and the model object as input, and returns the inferred predictions in a data frame as output.
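The two call sequences can be sketched with stubs. Everything here besides the `load_model` / `train_model` / `predict_model` names is illustrative:

```python
# Toy stand-ins: the real implementations live in each model package and
# return a fitted model object and a predictions data frame, respectively.
def load_model(manifest):
    return {"weights": "pretrained", "card": manifest}


def train_model(dataset, manifest):
    return {"weights": "fitted", "card": manifest}


def predict_model(dataset, manifest, model):
    # One prediction per dataset row (a list here, a data frame in reality).
    return [0.0 for _ in dataset]


def run_train(dataset, manifest, supervised: bool):
    """Dispatch mirroring the two call sequences described above."""
    if supervised:
        model = load_model(manifest)  # supervised: load the pretrained model
    else:
        model = train_model(dataset, manifest)  # zero-shot: fit the model first
    return predict_model(dataset, manifest, model)


predictions = run_train(dataset=[1, 2, 3], manifest={"name": "esm"}, supervised=True)
print(len(predictions))  # 3
```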
The result data frame is saved to disk in the local environment and stored in AWS S3 in the cloud environment, so that after the container is destroyed, the result data frame remains available for the later metric calculation.

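Where that file lands is environment-dependent. A minimal sketch, assuming the standard SageMaker convention that `SM_OUTPUT_DATA_DIR` points at a directory whose contents are uploaded to S3 after the job ends (the local fallback directory is illustrative):

```python
import os


def results_path(filename: str) -> str:
    """Return the path where the predictions data frame should be written."""
    # Inside a SageMaker container, SM_OUTPUT_DATA_DIR is set and its
    # contents are synced to S3 when the job finishes; locally we fall
    # back to a ./results directory (an assumption for this sketch).
    base = os.environ.get("SM_OUTPUT_DATA_DIR", "./results")
    return os.path.join(base, filename)
```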
For reference, below is an example Python implementation with `typer`:

```python
# In `__main__.py`
from pathlib import Path
from typing import Annotated

import typer

from pg2_dataset.dataset import Dataset
# Imports of `Manifest`, `load_model`, `train_model` and `predict_model`
# are elided here.


def train(
    dataset_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the archived dataset",
        ),
    ],
    model_card_reference: Annotated[
        Path,
        typer.Option(
            help="Path to the model card file",
        ),
    ],
) -> Path:
    dataset = Dataset.from_path(dataset_reference)
    manifest = Manifest.from_path(model_card_reference)

    # For a supervised model
    model = load_model(manifest)
    df = predict_model(dataset, manifest, model)
    df.to_csv(...)

    # For a zero-shot model
    model = train_model(dataset, manifest)
    df = predict_model(dataset, manifest, model)
    df.to_csv(...)


if __name__ == "__main__":
    typer.run(train)
```
…following code structure:

```tree
├── __main__.py
├── model.py
├── preprocess.py
└── utils.py
```

### `__main__.py`

The `__main__.py` contains the `train` entrypoint as shown above.
The code loads the dataset and model (card) before passing it to the `load_model`, `train_model`
or `predict_model` methods.

### `preprocess.py`

`preprocess.py` contains the data preprocessing code, functions like: