Skip to content

Commit d81b443

Browse files
committed
fix test, add analyse
1 parent 78be258 commit d81b443

5 files changed

Lines changed: 978 additions & 97 deletions

File tree

analyse.ipynb

Lines changed: 861 additions & 0 deletions
Large diffs are not rendered by default.

src/mixins.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class TorchModelTrainMixin:
2222
checkpoint: str = ""
2323
lr: float = 2e-5
2424
device: torch.device
25+
format_labels_as_long: bool = True
2526

2627
def sample_dataset(self, frac=0.1):
2728
dataset_size = len(self.dataset)
@@ -53,10 +54,11 @@ def params_optim(self, trial):
5354

5455
def objective(self, trial):
    """
    Run one hyper-parameter trial and return its score.

    The parameters proposed by ``params_optim(trial)`` are applied through
    ``reinit_scheduler_optimizer`` and the model is trained once. When
    ``self.tracking`` is true the trial runs inside a nested MLflow run and
    the parameters are logged; otherwise the same training happens without
    touching MLflow.

    :param trial: trial object forwarded unchanged to ``params_optim``.
    :return: the value returned by ``self.train()``.
    """
    kwargs = self.params_optim(trial)
    if not self.tracking:
        # Tracking only controls logging: the trial must still train and
        # produce a score for the optimiser.
        self.reinit_scheduler_optimizer(**kwargs)
        return self.train()
    with mlflow.start_run(nested=True):
        mlflow.log_params(kwargs)
        self.reinit_scheduler_optimizer(**kwargs)
        acc = self.train()
    return acc
6163

6264
def get_ddp_dataloader(self, frac=1.0):
@@ -71,7 +73,7 @@ def _train_batch(self, x, y):
7173
if isinstance(inputs, dict) and inputs["input_ids"]:
7274
inputs["input_ids"] = inputs["input_ids"].float()
7375
# todo : fix this case for bert vs lstm
74-
if True and isinstance(y, torch.Tensor) and y.dtype == torch.float32:
76+
if self.format_labels_as_long and isinstance(y, torch.Tensor) and y.dtype == torch.float32:
7577
labels = y.long()
7678
else:
7779
labels = y.float()
@@ -100,10 +102,12 @@ def _train_batch(self, x, y):
100102
else:
101103
raise e
102104
self.optimizer.step()
103-
logger.info(f" Rank {dist.get_rank()} loss {loss.item()} acc {acc}")
104-
mlflow.log_metric("loss", loss.item())
105-
mlflow.log_metric("acc", acc)
106-
mlflow.log_metric("time", time.time())
105+
if dist.is_initialized():
106+
logger.info(f" Rank {dist.get_rank()} loss {loss.item()} acc {acc}")
107+
if self.tracking:
108+
mlflow.log_metric("loss", loss.item())
109+
mlflow.log_metric("acc", acc)
110+
mlflow.log_metric("time", time.time())
107111
del inputs, labels, outputs, loss
108112
gc.collect()
109113
if torch.backends.mps.is_available():

src/ml.py

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import mlflow
3434
from mlflow.data.pandas_dataset import from_pandas
3535
from mlflow.models import infer_signature
36+
from transformers.models.tapas.modeling_tapas import flatten
37+
3638
from .mixins import TorchModelTrainMixin
3739
from .torch_models import LSTMTorchNN
3840
from .dataset import TweetDataset
@@ -72,11 +74,11 @@ class BaseModelABC(ABC):
7274
dataset_class = None
7375
tokenizer_class = None
7476
batch_size = 32
75-
#artifact_uri = "file:///app/mlruns"
7677

77-
def __init__(self, dataset: pd.DataFrame):
78+
def __init__(self, dataset: pd.DataFrame, tracking: bool = True):
79+
self.tracking = tracking
7880
self.dataset = None
79-
if dataset:
81+
if dataset is not None:
8082
self.original_dataset = dataset
8183
self.x_train, self.x_test, self.x_val, self.y_train, self.y_test, self.y_val = split_data(dataset)
8284
self.name = self.__class__.__name__
@@ -112,10 +114,12 @@ def init_items(self):
112114
"""
113115
self.model = None
114116
self.tokenizer = None
117+
self.tokenizer.fit_transform(self.x_train)
115118

116119
def init_mlflow(self, name: str = ""):
    """
    Start the MLflow run for this model when tracking is enabled.

    ``self.run`` and ``self.run_id`` are always assigned: they hold the
    started run and its id when ``self.tracking`` is true, and ``None``
    otherwise, so reading them never raises ``AttributeError``.

    :param name: run name; an empty value falls back to ``self.name``.
    """
    self.run = None
    self.run_id = None
    if self.tracking:
        self.run = mlflow.start_run(run_name=name or self.name)
        self.run_id = self.run.info.run_id
119123

120124
def load_checkpoint(self) -> object:
121125
"""
@@ -135,11 +139,13 @@ def preprocessing(self, data):
135139

136140
def confusion_matrix(self):
137141
with NamedTemporaryFile(suffix=".png") as f:
138-
conf_mat = confusion_matrix(self.y_train, self.predict(list(self.x_train)))
142+
y_pred = self.predict(list(self.x_train), flatten=True)
143+
conf_mat = confusion_matrix(self.y_train, y_pred)
139144
group_names = ["True Neg", "False Pos", "False Neg", "True Pos"]
140145
group_counts = [f"{value: 0.0f}" for value in conf_mat.flatten()]
141146
group_percentages = [f"{value:.2%}" for value in conf_mat.flatten() / np.sum(conf_mat)]
142147
labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in zip(group_names, group_counts, group_percentages)]
148+
print(labels)
143149
labels = np.asarray(labels).reshape(2, 2)
144150
sns.heatmap(
145151
conf_mat,
@@ -151,46 +157,49 @@ def confusion_matrix(self):
151157
plt.ylabel("Cluster réels")
152158
plt.savefig(f.name)
153159
plt.close()
154-
mlflow.log_artifact(f.name, "confusion_matrix.png")
160+
if self.tracking:
161+
mlflow.log_artifact(f.name, "confusion_matrix.png")
155162

156163
def train(self):
    """
    Train the model here

    Base implementation: persist the model with ``save()`` and draw the
    confusion matrix. When ``self.tracking`` is true it additionally tags
    the active MLflow run with the model type, logs metrics, the checkpoint
    artifact and the training rows as a dataset input, logs the model with
    the flavour matching its type (transformers, LightGBM or scikit-learn)
    under the registered name ``"<name>-quickstart"``, and ends the run.

    If any tracking step raises, the run is ended with status ``FAILED``
    before the exception propagates, so no run is left active.
    """
    self.save()
    if not self.tracking:
        self.confusion_matrix()
        return
    try:
        mlflow.set_tag("model_type", self.name)
        self.log_metrics()
        self.confusion_matrix()
        mlflow.log_artifact(self.checkpoint)
        signature = infer_signature(self.x_train, self.predict(self.x_train))
        dataset = from_pandas(self.original_dataset.loc[self.x_train.index], source="local")
        mlflow.log_input(dataset, context="tweet-dataset")
        registered_name = f"{self.name}-quickstart"
        if isinstance(self.model, PreTrainedModel):
            mlflow.transformers.log_model(
                transformers_model=self.checkpoint,
                artifact_path=self.name,
                task="text-classification",  # important !
                tokenizer=self.tokenizer,
                signature=signature,
                registered_model_name=registered_name,
            )
        elif isinstance(self.model, lgm.LGBMClassifier):
            mlflow.lightgbm.log_model(
                lgb_model=self.model,
                artifact_path=self.name,
                signature=signature,
                registered_model_name=registered_name,
            )
        else:
            mlflow.sklearn.log_model(
                sk_model=self.model,
                artifact_path=self.name,
                signature=signature,
                registered_model_name=registered_name,
            )
    except Exception:
        # Close the run explicitly so a failed logging step does not leave
        # it active for whatever runs next in this process.
        mlflow.end_run(status="FAILED")
        raise
    mlflow.end_run()
192201

193-
def predict(self, x: Union[pd.Series, np.ndarray]):
202+
def predict(self, x: Union[pd.Series, np.ndarray], flatten: bool = True):
194203
"""
195204
Predict the sentiment of the input data
196205
"""
@@ -200,7 +209,6 @@ def predict(self, x: Union[pd.Series, np.ndarray]):
200209
class SklearnBaseModel(BaseModelABC):
201210
def log_metrics(self):
202211
super().log_metrics()
203-
#mlflow.sklearn.log_model(self.model, self.name)
204212
mlflow.log_params(self.model.get_params())
205213

206214
class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
@@ -209,13 +217,13 @@ class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
209217
"""
210218
distributed = False
211219

212-
def __init__(self, dataset: pd.DataFrame):
220+
def __init__(self, dataset: pd.DataFrame, tracking: bool = True):
213221
if self.distributed:
214222
dist.init_process_group("nccl")
215223
if dist.is_initialized():
216224
self.local_rank = dist.get_rank()
217225
torch.cuda.set_device(self.local_rank)
218-
super().__init__(dataset)
226+
super().__init__(dataset, tracking)
219227
if dist.is_initialized():
220228
self.dataloader, self.sampler = self.get_ddp_dataloader()
221229
logger.info(f"Rank {dist.get_rank()} using DDP")
@@ -257,6 +265,8 @@ def init_items(self):
257265
self.tokenizer = self.tokenizer_class(
258266
max_features=1000, ngram_range=(1, 2), binary=True
259267
)
268+
self.tokenizer.fit_transform(self.x_train)
269+
260270

261271
def train(self):
262272
"""
@@ -297,6 +307,7 @@ def init_items(self):
297307
self.tokenizer = self.tokenizer_class(
298308
max_features=1000, min_df=2, max_df=0.95
299309
)
310+
self.tokenizer.fit_transform(self.x_train)
300311

301312
def clean(self, tweet):
302313
translator = str.maketrans('','', string.punctuation)
@@ -348,10 +359,11 @@ def init_items(self):
348359
"""
349360
self.model = LogisticRegression(max_iter=1000,
350361
C=1.7279373898388395,
351-
penalty='l1',
362+
penalty='l2',
352363
n_jobs=4,
353364
verbose=True)
354365
self.tokenizer = self.tokenizer_class()
366+
self.tokenizer.fit_transform(self.x_train)
355367

356368
def objective(self, tokens, params):
357369
with mlflow.start_run(nested=True):
@@ -431,6 +443,7 @@ def load_checkpoint(self):
431443
num_training_steps=total_steps
432444
)
433445
self.criterion = torch.nn.CrossEntropyLoss()
446+
self.model.to(self.device)
434447

435448
def save(self):
436449
if self.distributed and dist.is_initialized():
@@ -439,7 +452,7 @@ def save(self):
439452
self.model.save_pretrained(self.checkpoint)
440453
self.tokenizer.save_pretrained(self.checkpoint)
441454

442-
def predict(self, x: list):
455+
def predict(self, x: list, flatten: bool = False):
443456
predicted_class = []
444457
for i in range(0, len(x), self.batch_size):
445458
inputs = self.preprocessing(x[i:i+self.batch_size])
@@ -448,10 +461,13 @@ def predict(self, x: list):
448461
outputs = self.model(**inputs)
449462
probs = F.softmax(outputs.logits, dim=1)
450463
confidence, categorie = probs.max(dim=1)
451-
predicted_class.extend([{'prediction': categorie.item(),
452-
'confidence': confidence.item()}
453-
for confidence, categorie in zip(confidence, categorie)
454-
])
464+
if not flatten:
465+
predicted_class.extend([{'prediction': categorie.item(),
466+
'confidence': confidence.item()}
467+
for confidence, categorie in zip(confidence, categorie)
468+
])
469+
else:
470+
predicted_class.extend([c.item() for c in categorie])
455471
return predicted_class
456472

457473
class RobertaModel(BertModel):
@@ -474,7 +490,7 @@ class LSTMModel(TorchBaseModel):
474490
out_features = 1
475491
lr = 1e-4
476492
device = DEVICE
477-
# torch.nn.CrossEntropyLoss()
493+
format_labels_as_long: bool = False
478494

479495
@property
480496
def get_metrics(self) -> dict:
@@ -540,7 +556,7 @@ def save(self):
540556
self.tokenizer.save_pretrained(self.checkpoint)
541557
torch.save(self.model.state_dict(), self.checkpoint + "/model.pth")
542558

543-
def predict(self, x):
559+
def predict(self, x, flatten: bool = True):
544560
predicted_class = []
545561
for i in range(0, len(x), self.batch_size):
546562
inputs = self.preprocessing(x[i:i+self.batch_size])

src/server.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from fastapi.routing import APIRouter
55
from fastapi.requests import Request
66
from fastapi.responses import Response
7-
from fastapi import Form, WebSocket
7+
from fastapi import WebSocket
88
from fastapi.templating import Jinja2Templates
9-
from multiprocessing import Process, Pipe
9+
from multiprocessing import Pipe
1010
from multiprocessing.pool import Pool
1111
from typing import List
1212
from rich.logging import RichHandler
@@ -15,10 +15,9 @@
1515
import pymongo
1616
import uuid
1717
import time
18-
import random
1918
import json
20-
from .ml import BertModel, RobertaModel
21-
from .models import Tweet, Sentiment
19+
from .ml import RobertaModel
20+
from .models import Tweet
2221

2322
# Configuration des métriques Prometheus
2423
PREDICTION_COUNT = prom.Counter(
@@ -55,13 +54,21 @@ def get_pool():
5554
return pool
5655

5756

58-
def run_predict(text: List[Tweet], sender):
57+
def run_predict(text: List[Tweet], sender, save_db=False):
5958
FLAG_START = "started"
6059
FLAG_DONE = "done"
6160
logger.info(f"prediction started {text}")
6261
sender.send(FLAG_START)
6362
start_time = time.time()
6463
result = RobertaModel(dataset=None).predict([t.text for t in text])
64+
if save_db:
65+
with pymongo.MongoClient(MONGO_URI) as client:
66+
db = client["sentiment_analyses"]
67+
collection = db["tweets"]
68+
# Convertir la liste de tweets en liste de documents
69+
tweets = [{"text": str(t), "prediction": r['prediction'], "confidence": r['confidence']} for r, t in zip(result, text)]
70+
collection.insert_many(tweets)
71+
logger.info(f"tweets added to db: {len(tweets)} tweets")
6572
logger.info(f"prediction done {text}")
6673
sender.send(FLAG_DONE)
6774
result = [{'prediction': r['prediction'], 'confidence': r['confidence'], 'text': t.text} for r, t in zip(result, text)]
@@ -70,9 +77,8 @@ def run_predict(text: List[Tweet], sender):
7077

7178
class PredictApp:
7279
ACK_TIMEOUT = 1.0
73-
7480

75-
def __init__(self):
81+
def __init__(self, save_db=False):
7682
self.router = APIRouter()
7783
self.router.add_api_route("/", self.get, methods=["GET"])
7884
self.router.add_api_route("/predict", self.predict, methods=["POST"])
@@ -81,6 +87,7 @@ def __init__(self):
8187
self.active_connections = {}
8288
self.tasks = {}
8389
self.pipes = {}
90+
self.save_db = save_db
8491

8592
async def get(self, request: Request):
8693
""""""
@@ -178,16 +185,9 @@ async def predict(self, request: Request, text: List[Tweet]):
178185
"""
179186
Predict
180187
"""
181-
with pymongo.MongoClient(MONGO_URI) as client:
182-
db = client["sentiment_analyses"]
183-
collection = db["tweets"]
184-
# Convertir la liste de tweets en liste de documents
185-
tweets = [{"text": str(tweet)} for tweet in text]
186-
collection.insert_many(tweets)
187-
logger.info(f"tweets added to db: {len(tweets)} tweets")
188188
p = get_pool()
189189
pipe = Pipe()
190-
result = p.apply_async(run_predict, (text, pipe[1]))
190+
result = p.apply_async(run_predict, (text, pipe[1], self.save_db))
191191
PREDICTION_STATUS.labels("bert").inc()
192192
task_id = str(uuid.uuid4())
193193
self.tasks[task_id] = result
@@ -196,6 +196,6 @@ async def predict(self, request: Request, text: List[Tweet]):
196196
return {"task_id": task_id, "status": "processing"}
197197

198198

199-
predict_app = PredictApp()
199+
predict_app = PredictApp(save_db=True)
200200
app.include_router(predict_app.router)
201201

0 commit comments

Comments
 (0)