Skip to content

Commit 619da71

Browse files
Merge pull request #83 from ProteinGym2/fix/path-object
Remove default for Typer.Option for pathlib Path object
2 parents dad4a9a + 9ce9f2d commit 619da71

File tree

3 files changed

+15
-18
lines changed

3 files changed

+15
-18
lines changed

models/esm/src/pg2_model_esm/__main__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,15 @@ def train(
3434
dataset_file: Annotated[
3535
Path,
3636
typer.Option(
37-
default=SageMakerTrainingJobPath.TRAINING_JOB_PATH,
3837
help="Path to the dataset file",
3938
),
40-
],
39+
] = SageMakerTrainingJobPath.TRAINING_JOB_PATH,
4140
model_toml_file: Annotated[
4241
Path,
4342
typer.Option(
44-
default=SageMakerTrainingJobPath.MANIFEST_PATH,
4543
help="Path to the model TOML file",
4644
),
47-
],
45+
] = SageMakerTrainingJobPath.MANIFEST_PATH,
4846
):
4947
console.print(f"Loading {dataset_file} and {model_toml_file}...")
5048

models/pls/src/pg2_model_pls/__main__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,15 @@ def train(
3232
dataset_file: Annotated[
3333
Path,
3434
typer.Option(
35-
default=SageMakerTrainingJobPath.TRAINING_JOB_PATH,
3635
help="Path to the dataset file",
3736
),
38-
],
37+
] = SageMakerTrainingJobPath.TRAINING_JOB_PATH,
3938
model_toml_file: Annotated[
4039
Path,
4140
typer.Option(
42-
default=SageMakerTrainingJobPath.MANIFEST_PATH,
4341
help="Path to the model TOML file",
4442
),
45-
],
43+
] = SageMakerTrainingJobPath.MANIFEST_PATH,
4644
):
4745
console.print(f"Loading {dataset_file} and {model_toml_file}...")
4846

models/pls/src/pg2_model_pls/utils.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from typing import Any
33
import pickle
4+
from pathlib import Path
45
from sklearn.cross_decomposition import PLSRegression
56
from pg2_dataset.dataset import Dataset
67
from pg2_dataset.backends.assays import SPLIT_STRATEGY_MAPPING
@@ -125,8 +126,8 @@ def encode(spit_X: list[Any], hyper_params: dict[str, Any]) -> np.ndarray:
125126
def train_model(
126127
train_X: list[list[Any]],
127128
train_Y: list[Any],
128-
model_toml_file: str,
129-
model_path: str,
129+
model_toml_file: Path,
130+
model_path: Path,
130131
) -> None:
131132
"""Train a PLS regression model on encoded protein sequences and save it to disk.
132133
@@ -139,9 +140,9 @@ def train_model(
139140
Each inner list represents a single sequence.
140141
train_Y (list[Any]): Training target values corresponding to the sequences
141142
in train_X.
142-
model_toml_file (str): Path to the TOML configuration file containing model
143+
model_toml_file (Path): Path to the TOML configuration file containing model
143144
hyperparameters, including encoding parameters and n_components for PLS.
144-
model_path (str): File path where the trained model will be saved as a
145+
model_path (Path): File path where the trained model will be saved as a
145146
pickled object.
146147
147148
Returns:
@@ -167,16 +168,16 @@ def train_model(
167168
model = PLSRegression(manifest.hyper_params["n_components"])
168169
model.fit(encodings, train_Y)
169170

170-
with open(model_path, "wb") as file:
171+
with model_path.open(mode="wb") as file:
171172
pickle.dump(model, file)
172173

173174
logger.info(f"Saved the model to {model_path}")
174175

175176

176177
def predict_model(
177178
test_X: list[list[Any]],
178-
model_toml_file: str,
179-
model_path: str,
179+
model_toml_file: Path,
180+
model_path: Path,
180181
) -> list[Any]:
181182
"""Load a trained model and generate predictions on test sequences.
182183
@@ -187,9 +188,9 @@ def predict_model(
187188
Args:
188189
test_X (list[list[Any]]): Test feature data containing protein sequences.
189190
Each inner list represents a single sequence to predict on.
190-
model_toml_file (str): Path to the TOML configuration file containing model
191+
model_toml_file (Path): Path to the TOML configuration file containing model
191192
hyperparameters used for consistent encoding of test sequences.
192-
model_path (str): File path to the saved pickled model to load for prediction.
193+
model_path (Path): File path to the saved pickled model to load for prediction.
193194
194195
Returns:
195196
list[Any]: List of predictions corresponding to each sequence in test_X.
@@ -211,7 +212,7 @@ def predict_model(
211212
"""
212213
logger.info(f"Testing the model with {len(test_X)} records.")
213214

214-
with open(model_path, "rb") as file:
215+
with model_path.open(mode="rb") as file:
215216
model = pickle.load(file)
216217

217218
manifest = Manifest.from_path(model_toml_file)

0 commit comments

Comments
 (0)