11import numpy as np
22import joblib
33import os
4+ import re
5+ import string
46from functools import partial
57from pathlib import Path
68from abc import ABC
1315from torch .utils .data import DataLoader
1416import torch .nn .functional as F
1517from sklearn .linear_model import LogisticRegression
16- from lightgbm import LGBMClassifier
18+ import lightgbm as lgm
1719from sklearn .model_selection import train_test_split , cross_val_score
1820from sklearn .feature_extraction .text import TfidfVectorizer , CountVectorizer
1921from sklearn .metrics import confusion_matrix , classification_report
@@ -112,7 +114,6 @@ def init_mlflow(self, name:str = ""):
112114 self .run = mlflow .start_run (run_name = name if name else self .name )
113115 self .run_id = self .run .info .run_id
114116
115-
116117 def load_checkpoint (self ) -> object :
117118 """
118119 Logic to load the model from a checkpoint or create a new one
@@ -147,7 +148,7 @@ def confusion_matrix(self):
147148 plt .ylabel ("Cluster réels" )
148149 plt .savefig (f .name )
149150 plt .close ()
150- mlflow .log_artifact (f .name ) # , "confusion_matrix.png")
151+ mlflow .log_artifact (f .name , "confusion_matrix.png" )
151152
152153 def train (self ):
153154 """
@@ -170,6 +171,13 @@ def train(self):
170171 signature = signature ,
171172 registered_model_name = f"{ self .name } -quickstart"
172173 )
174+ elif isinstance (self .model , lgm .LGBMClassifier ):
175+ mlflow .lightgbm .log_model (
176+ lgb_model = self .model ,
177+ artifact_path = self .name ,
178+ signature = signature ,
179+ registered_model_name = f"{ self .name } -quickstart" ,
180+ )
173181 else :
174182 mlflow .sklearn .log_model (
175183 sk_model = self .model ,
@@ -189,7 +197,7 @@ def predict(self, x: Union[pd.Series, np.ndarray]):
class SklearnBaseModel(BaseModelABC):
    """Base class for models whose estimator exposes ``get_params()``."""

    def log_metrics(self):
        """Run the parent's metric logging, then log the estimator's parameters.

        The parameters returned by ``self.model.get_params()`` are sent to
        MLflow with ``mlflow.log_params``.
        """
        super().log_metrics()
        # The model artifact is deliberately not logged from this method.
        estimator_params = self.model.get_params()
        mlflow.log_params(estimator_params)
194202
195203class TorchBaseModel (TorchModelTrainMixin , BaseModelABC ):
@@ -257,27 +265,34 @@ class LightGBMModel(SklearnBaseModel):
def log_metrics(self):
    """Callback that logs the metrics to MLflow.

    Delegates entirely to the parent implementation; no LightGBM-specific
    metric is logged here.
    """
    super().log_metrics()

261269
def init_items(self):
    """Create the LightGBM classifier and the text vectorizer for this model.

    Sets ``self.model`` to an ``LGBMClassifier`` built with library defaults
    and ``self.tokenizer`` to an instance of ``self.tokenizer_class``.
    """
    # Library defaults are used; earlier hand-picked values, kept for reference:
    # n_estimators=100, learning_rate=0.1, max_depth=6, random_state=42
    self.model = lgm.LGBMClassifier()

    # NOTE(review): presumably a scikit-learn text vectorizer, where min_df /
    # max_df bound the document frequency of kept terms - confirm against
    # the tokenizer_class actually configured.
    vectorizer_options = {"max_features": 1000, "min_df": 2, "max_df": 0.95}
    self.tokenizer = self.tokenizer_class(**vectorizer_options)
269277
def clean(self, tweet):
    """Normalize a raw tweet before vectorization.

    Steps, in order:
      1. delete every ASCII punctuation character (``string.punctuation``);
      2. replace every remaining character that is not an ASCII letter
         (digits, symbols, non-ASCII characters) with a space;
      3. lowercase the text;
      4. collapse runs of whitespace and strip both ends.

    Args:
        tweet: the raw tweet text (a ``str``).

    Returns:
        The cleaned, lowercase, single-spaced text.
    """
    without_punctuation = tweet.translate(
        str.maketrans("", "", string.punctuation)
    )
    # The previous pattern "^[a-z][A-Z]" only matched a lowercase letter
    # followed by an uppercase one at the very start of the string, so it
    # removed the first two letters of inputs such as "aBcd" and never
    # removed digits or symbols. The negated class below keeps letters only.
    letters_only = re.sub(r"[^a-zA-Z]", " ", without_punctuation)
    return " ".join(letters_only.lower().split())

285+
def train(self):
    """
    Train the LightGBM classifier on the cleaned, vectorized training texts.

    Starts the MLflow run via ``init_mlflow``, applies ``clean`` to every
    entry of ``self.x_train``, fits ``self.tokenizer`` on the cleaned texts,
    fits ``self.model`` on the resulting features and ``self.y_train``, then
    hands over to the parent class's ``train``.

    Raises:
        Exception: any error raised during these steps is logged and
            re-raised unchanged.
    """
    self.init_mlflow()
    try:
        # Text vectorization
        cleaned_texts = self.x_train.apply(self.clean)
        features = self.tokenizer.fit_transform(cleaned_texts)
        self.model.fit(features, self.y_train)
        super().train()
    except Exception as exc:
        logger.error(f"Erreur pendant l'entraînement: {str(exc)} ")
        raise
@@ -513,4 +528,5 @@ def load_data(path):
513528 df_tweets = pd .read_csv (path , names = headers , encoding = "latin-1" )
514529 # On prend target 0 negatif 1 positif
515530 df_tweets .loc [:, "target" ] = df_tweets .target .map ({0 : int (0 ), 4 : int (1 )})
531+
516532 return df_tweets
0 commit comments