Skip to content

Commit 0c83668

Browse files
committed
uptuna
1 parent 73b36f5 commit 0c83668

5 files changed

Lines changed: 110 additions & 30 deletions

File tree

mlruns/0/meta.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
artifact_location: file:///app/mlruns/0
2-
creation_time: 1744912772300
1+
artifact_location: file:///workspace/sentimental_analyses/mlruns/0
2+
creation_time: 1744984721760
33
experiment_id: '0'
4-
last_update_time: 1744912772300
4+
last_update_time: 1744984721760
55
lifecycle_stage: active
66
name: Default

src/mixins.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import mlflow
22
import torch
33
from tqdm import tqdm
4+
import numpy as np
45
import gc
56
import time
67
import logging
8+
import optuna
9+
import random
10+
from torch.utils.data import Subset, DataLoader
711

812
logger = logging.getLogger(__name__)
913

@@ -17,6 +21,40 @@ class TorchModelTrainMixin:
1721
lr: float = 2e-5
1822
device: torch.device
1923

24+
def get_sampled_dataloader(self, frac=0.1):
25+
dataset_size = len(self.dataset)
26+
sample_size = int(frac * dataset_size)
27+
indices = random.sample(range(dataset_size), sample_size)
28+
sampled_dataset = Subset(self.dataset, indices)
29+
return DataLoader(sampled_dataset, batch_size=self.batch_size, shuffle=True)
30+
31+
def optuna_train(self, run_name:str = "", n_trials:int=30, frac=0.1):
32+
self.init_mlflow(run_name)
33+
self.dataloader = self.get_sampled_dataloader(frac=frac)
34+
self.x_val = self.x_val.sample(frac=frac, random_state=42)
35+
self.y_val = self.y_val.sample(frac=frac, random_state=42)
36+
self.x_train = self.x_train.sample(frac=frac, random_state=42)
37+
self.y_train = self.y_train.sample(frac=frac, random_state=42)
38+
study = optuna.create_study(direction="maximize",
39+
pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=1))
40+
study.optimize(self.objective, n_trials=n_trials)
41+
42+
43+
def objective(self, trial):
44+
lr = trial.suggest_loguniform('lr', 1e-6, 1e-3)
45+
gamma = trial.suggest_float('gamma', 0.1, 0.9)
46+
step_size = trial.suggest_int('step_size', 2, 10)
47+
with mlflow.start_run(nested=True):
48+
mlflow.log_params({
49+
"lr": lr,
50+
"gamma": gamma,
51+
"step_size": step_size
52+
})
53+
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
54+
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)
55+
acc = self.train()
56+
return acc
57+
2058
def _train_batch(self, x, y):
2159
inputs = self.tokenizer(x, return_tensors="pt", truncation=True, padding=True)
2260
if isinstance(inputs, dict) and inputs["input_ids"]:
@@ -38,7 +76,7 @@ def _train_batch(self, x, y):
3876
acc = correct / len(labels)
3977
loss.backward()
4078
self.optimizer.step()
41-
logger.info(loss.item())
79+
logger.info(f"loss {loss.item()}")
4280
mlflow.log_metric("loss", loss.item())
4381
mlflow.log_metric("acc", acc)
4482
mlflow.log_metric("time", time.time())
@@ -47,17 +85,22 @@ def _train_batch(self, x, y):
4785
if torch.backends.mps.is_available():
4886
torch.mps.empty_cache()
4987
time.sleep(0.2)
88+
return acc
5089

5190
def train(self):
52-
self.init_mlflow()
91+
try:
92+
self.init_mlflow()
93+
except Exception:
94+
logger.info("mlflow run already started, you had launched train with optuna")
95+
pass
5396
self.model.train()
5497
self.model.to(self.device)
98+
acc = []
5599
for epoch in tqdm(range(self.epoch)):
56100
for tweets, labels in tqdm(self.dataloader):
57101
try:
58-
self._train_batch(tweets, labels.float())
102+
acc.append(self._train_batch(tweets, labels.float()))
59103
except RuntimeError as e:
60-
raise e
61104
logger.error(e)
62105
del tweets, labels, self.optimizer
63106
gc.collect()
@@ -79,6 +122,7 @@ def train(self):
79122
)
80123
self.scheduler.step()
81124
super().train()
125+
return sum(acc)/len(acc)
82126

83127

84128
class SklearnModelTrainMixin:

src/ml.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,21 @@
1111
import torch
1212
from tqdm import tqdm
1313
from torch.utils.data import DataLoader
14+
import torch.nn.functional as F
1415
from sklearn.linear_model import LogisticRegression
1516
from lightgbm import LGBMClassifier
1617
from sklearn.model_selection import train_test_split, cross_val_score
1718
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
1819
from sklearn.metrics import confusion_matrix, classification_report
20+
from sklearn.ensemble import RandomForestClassifier
1921
import seaborn as sns
2022
from transformers import AutoTokenizer, AutoModelForSequenceClassification
2123
import pandas as pd
22-
from sklearn.ensemble import RandomForestClassifier
2324
from skopt import BayesSearchCV, gp_minimize
2425
from skopt.space import Real, Categorical
2526
from skopt.utils import use_named_args
2627
import logging
27-
import torch.nn.functional as F
28+
from transformers import PreTrainedModel
2829
import mlflow
2930
from mlflow.data.pandas_dataset import from_pandas
3031
from mlflow.models import infer_signature
@@ -44,7 +45,7 @@
4445
DEVICE = torch.device("cpu")
4546
logger.info("Using CPU")
4647

47-
DEVICE = torch.device("cpu")
48+
#DEVICE = torch.device("cpu")
4849

4950
SENTIMENT_LABELS = {
5051
0: "😡 unsatisfy",
@@ -67,7 +68,7 @@ class BaseModelABC(ABC):
6768
dataset_class = None
6869
tokenizer_class = None
6970
batch_size = 32
70-
artifact_uri = "file:///app/mlruns"
71+
#artifact_uri = "file:///app/mlruns"
7172

7273
def __init__(self, dataset: pd.DataFrame):
7374
self.original_dataset = dataset
@@ -146,7 +147,7 @@ def confusion_matrix(self):
146147
plt.ylabel("Cluster réels")
147148
plt.savefig(f.name)
148149
plt.close()
149-
mlflow.log_artifact(f.name, "confusion_matrix.png")
150+
mlflow.log_artifact(f.name)#, "confusion_matrix.png")
150151

151152
def train(self):
152153
"""
@@ -160,12 +161,22 @@ def train(self):
160161
signature = infer_signature(self.x_train, self.predict(self.x_train))
161162
dataset = from_pandas(self.original_dataset.loc[self.x_train.index], source="local")
162163
mlflow.log_input(dataset, context="tweet-dataset")
163-
model_info = mlflow.sklearn.log_model(
164-
sk_model=self.model,
165-
artifact_path=self.name,
166-
signature=signature,
167-
registered_model_name=f"{self.name}-quickstart",
168-
)
164+
if isinstance(self.model, PreTrainedModel):
165+
mlflow.transformers.log_model(
166+
transformers_model=self.checkpoint,
167+
artifact_path=self.name,
168+
task="text-classification", # important !
169+
tokenizer=self.tokenizer,
170+
signature=signature,
171+
registered_model_name=f"{self.name}-quickstart"
172+
)
173+
else:
174+
mlflow.sklearn.log_model(
175+
sk_model=self.model,
176+
artifact_path=self.name,
177+
signature=signature,
178+
registered_model_name=f"{self.name}-quickstart",
179+
)
169180
mlflow.end_run()
170181

171182
def predict(self, x: Union[pd.Series, np.ndarray]):
@@ -341,9 +352,9 @@ class BertModel(TorchBaseModel):
341352
model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
342353
dataset_class = TweetDataset
343354
epoch = 1
344-
batch_size = 200
355+
batch_size = 100
345356
out_features = 2
346-
lr = 2e-5
357+
lr = 2.561e-4
347358
device = DEVICE
348359

349360
def load_checkpoint(self):
@@ -362,7 +373,7 @@ def load_checkpoint(self):
362373
)
363374
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
364375
self.scheduler = torch.optim.lr_scheduler.StepLR(
365-
self.optimizer, step_size=5, gamma=0.1
376+
self.optimizer, step_size=8, gamma=0.248
366377
)
367378
self.criterion = torch.nn.CrossEntropyLoss()
368379

@@ -371,14 +382,24 @@ def save(self):
371382
self.tokenizer.save_pretrained(self.checkpoint)
372383

373384
def predict(self, x: list):
374-
inputs = self.preprocessing(x)
375-
inputs = inputs.to(self.device)
376-
with torch.no_grad():
377-
outputs = self.model(**inputs)
378-
probs = F.softmax(outputs.logits, dim=1)
379-
predicted_class = torch.argmax(probs, dim=1).tolist()
385+
predicted_class = []
386+
for i in range(0, len(x), self.batch_size):
387+
inputs = self.preprocessing(x[i:i+self.batch_size])
388+
inputs = inputs.to(self.device)
389+
with torch.no_grad():
390+
outputs = self.model(**inputs)
391+
probs = F.softmax(outputs.logits, dim=1)
392+
predicted_class.extend(torch.argmax(probs, dim=1).tolist())
380393
return predicted_class
381394

395+
class RobertaModel(BertModel):
396+
"""
397+
Using a roberta base sentiment to predict tweet sentiments
398+
"""
399+
model_name = "cardiffnlp/twitter-roberta-base-sentiment"
400+
tokenizer_name = "cardiffnlp/twitter-roberta-base-sentiment"
401+
checkpoint = "checkpoints/roberta"
402+
382403

383404
class LSTMModel(TorchBaseModel):
384405
checkpoint = "checkpoints/lstm"

src/tests/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
LogisticRegressionModel,
77
load_data,
88
BertModel,
9+
RobertaModel,
910
LSTMModel,
1011
RandomForestModel,
1112
LightGBMModel,
1213
)
14+
from transformers import PreTrainedModel
1315
from ..server import app
1416

1517

@@ -58,6 +60,19 @@ def test_predict(self):
5860
def test_confusion_matrix(self):
5961
self.model.confusion_matrix()
6062

63+
def test_optuna_train(self):
64+
self.model.optuna_train(n_trials=5)
65+
66+
class TestRobertaModel(BaseTest):
67+
class_model = RobertaModel
68+
69+
def test_optuna_train(self):
70+
self.model.optuna_train(n_trials=5)
71+
72+
def test_predict(self):
73+
result = self.model.predict(list(self.x_test))
74+
assert result == [1, 1, 0, 0, 0, 0]
75+
6176

6277
class TestLSTMModel(BaseTest):
6378
class_model = LSTMModel

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from src.ml import load_data, RandomForestModel, LogisticRegressionModel, BertModel
1+
from src.ml import load_data, RandomForestModel, LogisticRegressionModel, BertModel, RobertaModel, LSTMModel
22
import logging
33
from rich.logging import RichHandler
44

@@ -7,5 +7,5 @@
77
)
88
file = "../data/training.1600000.processed.noemoticon.csv"
99
original_df = load_data(file)
10-
model = BertModel(original_df)
11-
model.train()
10+
model = LSTMModel(original_df)
11+
model.optuna_train(n_trials=5, frac=0.01)

0 commit comments

Comments
 (0)