11import numpy as np
22from typing import Any
33import pickle
4+ from pathlib import Path
45from sklearn .cross_decomposition import PLSRegression
56from pg2_dataset .dataset import Dataset
67from pg2_dataset .backends .assays import SPLIT_STRATEGY_MAPPING
@@ -125,8 +126,8 @@ def encode(spit_X: list[Any], hyper_params: dict[str, Any]) -> np.ndarray:
125126def train_model (
126127 train_X : list [list [Any ]],
127128 train_Y : list [Any ],
128- model_toml_file : str ,
129- model_path : str ,
129+ model_toml_file : Path ,
130+ model_path : Path ,
130131) -> None :
131132 """Train a PLS regression model on encoded protein sequences and save it to disk.
132133
@@ -139,9 +140,9 @@ def train_model(
139140 Each inner list represents a single sequence.
140141 train_Y (list[Any]): Training target values corresponding to the sequences
141142 in train_X.
142- model_toml_file (str): Path to the TOML configuration file containing model
143+ model_toml_file (Path): Path to the TOML configuration file containing model
143144 hyperparameters, including encoding parameters and n_components for PLS.
144- model_path (str): File path where the trained model will be saved as a
145+ model_path (Path): File path where the trained model will be saved as a
145146 pickled object.
146147
147148 Returns:
@@ -167,16 +168,16 @@ def train_model(
167168 model = PLSRegression (manifest .hyper_params ["n_components" ])
168169 model .fit (encodings , train_Y )
169170
170- with open (model_path , "wb" ) as file :
171+ with model_path . open (mode = "wb" ) as file :
171172 pickle .dump (model , file )
172173
173174 logger .info (f"Saved the model to { model_path } " )
174175
175176
176177def predict_model (
177178 test_X : list [list [Any ]],
178- model_toml_file : str ,
179- model_path : str ,
179+ model_toml_file : Path ,
180+ model_path : Path ,
180181) -> list [Any ]:
181182 """Load a trained model and generate predictions on test sequences.
182183
@@ -187,9 +188,9 @@ def predict_model(
187188 Args:
188189 test_X (list[list[Any]]): Test feature data containing protein sequences.
189190 Each inner list represents a single sequence to predict on.
190- model_toml_file (str): Path to the TOML configuration file containing model
191+ model_toml_file (Path): Path to the TOML configuration file containing model
191192 hyperparameters used for consistent encoding of test sequences.
192- model_path (str): File path to the saved pickled model to load for prediction.
193+ model_path (Path): File path to the saved pickled model to load for prediction.
193194
194195 Returns:
195196 list[Any]: List of predictions corresponding to each sequence in test_X.
@@ -211,7 +212,7 @@ def predict_model(
211212 """
212213 logger .info (f"Testing the model with { len (test_X )} records." )
213214
214- with open (model_path , "rb" ) as file :
215+ with model_path . open (mode = "rb" ) as file :
215216 model = pickle .load (file )
216217
217218 manifest = Manifest .from_path (model_toml_file )
0 commit comments