3333import mlflow
3434from mlflow .data .pandas_dataset import from_pandas
3535from mlflow .models import infer_signature
36+ from transformers .models .tapas .modeling_tapas import flatten
37+
3638from .mixins import TorchModelTrainMixin
3739from .torch_models import LSTMTorchNN
3840from .dataset import TweetDataset
@@ -72,11 +74,11 @@ class BaseModelABC(ABC):
7274 dataset_class = None
7375 tokenizer_class = None
7476 batch_size = 32
75- #artifact_uri = "file:///app/mlruns"
7677
77- def __init__ (self , dataset : pd .DataFrame ):
78+ def __init__ (self , dataset : pd .DataFrame , tracking : bool = True ):
79+ self .tracking = tracking
7880 self .dataset = None
79- if dataset :
81+ if dataset is not None :
8082 self .original_dataset = dataset
8183 self .x_train , self .x_test , self .x_val , self .y_train , self .y_test , self .y_val = split_data (dataset )
8284 self .name = self .__class__ .__name__
@@ -112,10 +114,12 @@ def init_items(self):
112114 """
113115 self .model = None
114116 self .tokenizer = None
117+ self .tokenizer .fit_transform (self .x_train )
115118
def init_mlflow(self, name: str = ""):
    """
    Start the MLflow run for this model when tracking is enabled.

    :param name: run name; when empty, ``self.name`` is used instead.

    ``self.run`` and ``self.run_id`` are always assigned (``None`` when
    tracking is disabled) so that later reads of either attribute do not
    raise ``AttributeError`` on an untracked model.
    """
    self.run = None
    self.run_id = None
    if not self.tracking:
        return
    self.run = mlflow.start_run(run_name=name or self.name)
    self.run_id = self.run.info.run_id
119123
120124 def load_checkpoint (self ) -> object :
121125 """
@@ -135,11 +139,13 @@ def preprocessing(self, data):
135139
136140 def confusion_matrix (self ):
137141 with NamedTemporaryFile (suffix = ".png" ) as f :
138- conf_mat = confusion_matrix (self .y_train , self .predict (list (self .x_train )))
142+ y_pred = self .predict (list (self .x_train ), flatten = True )
143+ conf_mat = confusion_matrix (self .y_train , y_pred )
139144 group_names = ["True Neg" , "False Pos" , "False Neg" , "True Pos" ]
140145 group_counts = [f"{ value : 0.0f} " for value in conf_mat .flatten ()]
141146 group_percentages = [f"{ value :.2%} " for value in conf_mat .flatten () / np .sum (conf_mat )]
142147 labels = [f"{ v1 } \n { v2 } \n { v3 } " for v1 , v2 , v3 in zip (group_names , group_counts , group_percentages )]
148+ print (labels )
143149 labels = np .asarray (labels ).reshape (2 , 2 )
144150 sns .heatmap (
145151 conf_mat ,
@@ -151,46 +157,49 @@ def confusion_matrix(self):
151157 plt .ylabel ("Cluster réels" )
152158 plt .savefig (f .name )
153159 plt .close ()
154- mlflow .log_artifact (f .name , "confusion_matrix.png" )
160+ if self .tracking :
161+ mlflow .log_artifact (f .name , "confusion_matrix.png" )
155162
def train(self):
    """
    Train the model here

    Persists the model via ``self.save()``, renders the confusion matrix
    and, when ``self.tracking`` is enabled, logs the tag, metrics,
    checkpoint, input dataset and model to the active MLflow run before
    closing it.

    If any tracked step raises, the active run is ended with status
    ``FAILED`` and the exception is re-raised, so a failed training does
    not leave a run open for the next ``mlflow.start_run`` call.
    """
    self.save()
    if not self.tracking:
        # No MLflow run exists in this mode; only produce the matrix.
        self.confusion_matrix()
        return
    try:
        mlflow.set_tag("model_type", self.name)
        self.log_metrics()
        self.confusion_matrix()
        self._log_model_to_mlflow()
    except Exception:
        mlflow.end_run(status="FAILED")
        raise
    mlflow.end_run()


def _log_model_to_mlflow(self):
    """
    Log the checkpoint, training dataset and fitted model to MLflow.

    The MLflow flavour is chosen from the type of ``self.model``:
    transformers for ``PreTrainedModel``, lightgbm for ``LGBMClassifier``
    and sklearn for everything else.
    """
    mlflow.log_artifact(self.checkpoint)
    signature = infer_signature(self.x_train, self.predict(self.x_train))
    # Only the rows that ended up in the training split are logged as input.
    dataset = from_pandas(self.original_dataset.loc[self.x_train.index], source="local")
    mlflow.log_input(dataset, context="tweet-dataset")
    registered_name = f"{self.name}-quickstart"
    if isinstance(self.model, PreTrainedModel):
        mlflow.transformers.log_model(
            transformers_model=self.checkpoint,
            artifact_path=self.name,
            task="text-classification",  # important !
            tokenizer=self.tokenizer,
            signature=signature,
            registered_model_name=registered_name,
        )
    elif isinstance(self.model, lgm.LGBMClassifier):
        mlflow.lightgbm.log_model(
            lgb_model=self.model,
            artifact_path=self.name,
            signature=signature,
            registered_model_name=registered_name,
        )
    else:
        mlflow.sklearn.log_model(
            sk_model=self.model,
            artifact_path=self.name,
            signature=signature,
            registered_model_name=registered_name,
        )
192201
193- def predict (self , x : Union [pd .Series , np .ndarray ]):
202+ def predict (self , x : Union [pd .Series , np .ndarray ], flatten : bool = True ):
194203 """
195204 Predict the sentiment of the input data
196205 """
@@ -200,7 +209,6 @@ def predict(self, x: Union[pd.Series, np.ndarray]):
class SklearnBaseModel(BaseModelABC):
    """
    Base class for models whose ``self.model`` exposes ``get_params()``.
    """

    def log_metrics(self):
        """
        Log the parent metrics, then the estimator's parameters to MLflow.
        """
        super().log_metrics()
        hyperparameters = self.model.get_params()
        mlflow.log_params(hyperparameters)
205213
206214class TorchBaseModel (TorchModelTrainMixin , BaseModelABC ):
@@ -209,13 +217,13 @@ class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
209217 """
210218 distributed = False
211219
212- def __init__ (self , dataset : pd .DataFrame ):
220+ def __init__ (self , dataset : pd .DataFrame , tracking : bool = True ):
213221 if self .distributed :
214222 dist .init_process_group ("nccl" )
215223 if dist .is_initialized ():
216224 self .local_rank = dist .get_rank ()
217225 torch .cuda .set_device (self .local_rank )
218- super ().__init__ (dataset )
226+ super ().__init__ (dataset , tracking )
219227 if dist .is_initialized ():
220228 self .dataloader , self .sampler = self .get_ddp_dataloader ()
221229 logger .info (f"Rank { dist .get_rank ()} using DDP" )
@@ -257,6 +265,8 @@ def init_items(self):
257265 self .tokenizer = self .tokenizer_class (
258266 max_features = 1000 , ngram_range = (1 , 2 ), binary = True
259267 )
268+ self .tokenizer .fit_transform (self .x_train )
269+
260270
261271 def train (self ):
262272 """
@@ -297,6 +307,7 @@ def init_items(self):
297307 self .tokenizer = self .tokenizer_class (
298308 max_features = 1000 , min_df = 2 , max_df = 0.95
299309 )
310+ self .tokenizer .fit_transform (self .x_train )
300311
301312 def clean (self , tweet ):
302313 translator = str .maketrans ('' ,'' , string .punctuation )
@@ -348,10 +359,11 @@ def init_items(self):
348359 """
349360 self .model = LogisticRegression (max_iter = 1000 ,
350361 C = 1.7279373898388395 ,
351- penalty = 'l1 ' ,
362+ penalty = 'l2 ' ,
352363 n_jobs = 4 ,
353364 verbose = True )
354365 self .tokenizer = self .tokenizer_class ()
366+ self .tokenizer .fit_transform (self .x_train )
355367
356368 def objective (self , tokens , params ):
357369 with mlflow .start_run (nested = True ):
@@ -431,6 +443,7 @@ def load_checkpoint(self):
431443 num_training_steps = total_steps
432444 )
433445 self .criterion = torch .nn .CrossEntropyLoss ()
446+ self .model .to (self .device )
434447
435448 def save (self ):
436449 if self .distributed and dist .is_initialized ():
@@ -439,7 +452,7 @@ def save(self):
439452 self .model .save_pretrained (self .checkpoint )
440453 self .tokenizer .save_pretrained (self .checkpoint )
441454
442- def predict (self , x : list ):
455+ def predict (self , x : list , flatten : bool = False ):
443456 predicted_class = []
444457 for i in range (0 , len (x ), self .batch_size ):
445458 inputs = self .preprocessing (x [i :i + self .batch_size ])
@@ -448,10 +461,13 @@ def predict(self, x: list):
448461 outputs = self .model (** inputs )
449462 probs = F .softmax (outputs .logits , dim = 1 )
450463 confidence , categorie = probs .max (dim = 1 )
451- predicted_class .extend ([{'prediction' : categorie .item (),
452- 'confidence' : confidence .item ()}
453- for confidence , categorie in zip (confidence , categorie )
454- ])
464+ if not flatten :
465+ predicted_class .extend ([{'prediction' : categorie .item (),
466+ 'confidence' : confidence .item ()}
467+ for confidence , categorie in zip (confidence , categorie )
468+ ])
469+ else :
470+ predicted_class .extend ([c .item () for c in categorie ])
455471 return predicted_class
456472
457473class RobertaModel (BertModel ):
@@ -474,7 +490,7 @@ class LSTMModel(TorchBaseModel):
474490 out_features = 1
475491 lr = 1e-4
476492 device = DEVICE
477- # torch.nn.CrossEntropyLoss()
493+ format_labels_as_long : bool = False
478494
479495 @property
480496 def get_metrics (self ) -> dict :
@@ -540,7 +556,7 @@ def save(self):
540556 self .tokenizer .save_pretrained (self .checkpoint )
541557 torch .save (self .model .state_dict (), self .checkpoint + "/model.pth" )
542558
543- def predict (self , x ):
559+ def predict (self , x , flatten : bool = True ):
544560 predicted_class = []
545561 for i in range (0 , len (x ), self .batch_size ):
546562 inputs = self .preprocessing (x [i :i + self .batch_size ])
0 commit comments