1111import torch
1212from tqdm import tqdm
1313from torch .utils .data import DataLoader
14+ import torch .nn .functional as F
1415from sklearn .linear_model import LogisticRegression
1516from lightgbm import LGBMClassifier
1617from sklearn .model_selection import train_test_split , cross_val_score
1718from sklearn .feature_extraction .text import TfidfVectorizer , CountVectorizer
1819from sklearn .metrics import confusion_matrix , classification_report
20+ from sklearn .ensemble import RandomForestClassifier
1921import seaborn as sns
2022from transformers import AutoTokenizer , AutoModelForSequenceClassification
2123import pandas as pd
22- from sklearn .ensemble import RandomForestClassifier
2324from skopt import BayesSearchCV , gp_minimize
2425from skopt .space import Real , Categorical
2526from skopt .utils import use_named_args
2627import logging
27- import torch . nn . functional as F
28+ from transformers import PreTrainedModel
2829import mlflow
2930from mlflow .data .pandas_dataset import from_pandas
3031from mlflow .models import infer_signature
4445 DEVICE = torch .device ("cpu" )
4546 logger .info ("Using CPU" )
4647
47- DEVICE = torch .device ("cpu" )
48+ # DEVICE = torch.device("cpu")
4849
4950SENTIMENT_LABELS = {
5051 0 : "😡 unsatisfy" ,
@@ -67,7 +68,7 @@ class BaseModelABC(ABC):
6768 dataset_class = None
6869 tokenizer_class = None
6970 batch_size = 32
70- artifact_uri = "file:///app/mlruns"
71+ # artifact_uri = "file:///app/mlruns"
7172
7273 def __init__ (self , dataset : pd .DataFrame ):
7374 self .original_dataset = dataset
@@ -146,7 +147,7 @@ def confusion_matrix(self):
146147 plt .ylabel ("Cluster réels" )
147148 plt .savefig (f .name )
148149 plt .close ()
149- mlflow .log_artifact (f .name , "confusion_matrix.png" )
150+ mlflow .log_artifact (f .name ) # , "confusion_matrix.png")
150151
151152 def train (self ):
152153 """
@@ -160,12 +161,22 @@ def train(self):
160161 signature = infer_signature (self .x_train , self .predict (self .x_train ))
161162 dataset = from_pandas (self .original_dataset .loc [self .x_train .index ], source = "local" )
162163 mlflow .log_input (dataset , context = "tweet-dataset" )
163- model_info = mlflow .sklearn .log_model (
164- sk_model = self .model ,
165- artifact_path = self .name ,
166- signature = signature ,
167- registered_model_name = f"{ self .name } -quickstart" ,
168- )
164+ if isinstance (self .model , PreTrainedModel ):
165+ mlflow .transformers .log_model (
166+ transformers_model = self .checkpoint ,
167+ artifact_path = self .name ,
168+ task = "text-classification" , # important !
169+ tokenizer = self .tokenizer ,
170+ signature = signature ,
171+ registered_model_name = f"{ self .name } -quickstart"
172+ )
173+ else :
174+ mlflow .sklearn .log_model (
175+ sk_model = self .model ,
176+ artifact_path = self .name ,
177+ signature = signature ,
178+ registered_model_name = f"{ self .name } -quickstart" ,
179+ )
169180 mlflow .end_run ()
170181
171182 def predict (self , x : Union [pd .Series , np .ndarray ]):
@@ -341,9 +352,9 @@ class BertModel(TorchBaseModel):
341352 model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
342353 dataset_class = TweetDataset
343354 epoch = 1
344- batch_size = 200
355+ batch_size = 100
345356 out_features = 2
346- lr = 2e-5
357+ lr = 2.561e-4
347358 device = DEVICE
348359
349360 def load_checkpoint (self ):
@@ -362,7 +373,7 @@ def load_checkpoint(self):
362373 )
363374 self .optimizer = torch .optim .Adam (self .model .parameters (), lr = self .lr )
364375 self .scheduler = torch .optim .lr_scheduler .StepLR (
365- self .optimizer , step_size = 5 , gamma = 0.1
376+ self .optimizer , step_size = 8 , gamma = 0.248
366377 )
367378 self .criterion = torch .nn .CrossEntropyLoss ()
368379
@@ -371,14 +382,24 @@ def save(self):
371382 self .tokenizer .save_pretrained (self .checkpoint )
372383
373384 def predict (self , x : list ):
374- inputs = self .preprocessing (x )
375- inputs = inputs .to (self .device )
376- with torch .no_grad ():
377- outputs = self .model (** inputs )
378- probs = F .softmax (outputs .logits , dim = 1 )
379- predicted_class = torch .argmax (probs , dim = 1 ).tolist ()
385+ predicted_class = []
386+ for i in range (0 , len (x ), self .batch_size ):
387+ inputs = self .preprocessing (x [i :i + self .batch_size ])
388+ inputs = inputs .to (self .device )
389+ with torch .no_grad ():
390+ outputs = self .model (** inputs )
391+ probs = F .softmax (outputs .logits , dim = 1 )
392+ predicted_class .extend (torch .argmax (probs , dim = 1 ).tolist ())
380393 return predicted_class
381394
395+ class RobertaModel (BertModel ):
396+ """
397+ Using a roberta base sentiment to predict tweet sentiments
398+ """
399+ model_name = "cardiffnlp/twitter-roberta-base-sentiment"
400+ tokenizer_name = "cardiffnlp/twitter-roberta-base-sentiment"
401+ checkpoint = "checkpoints/roberta"
402+
382403
383404class LSTMModel (TorchBaseModel ):
384405 checkpoint = "checkpoints/lstm"
0 commit comments