Skip to content

Commit d6aed1c

Browse files
committed
wip cloud gpu training
1 parent 93b6a50 commit d6aed1c

12 files changed

Lines changed: 149 additions & 95 deletions

File tree

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ services:
3636
volumes:
3737
- grafana-storage:/var/lib/grafana
3838
- ./grafana/dashboards:/etc/grafana/dashboards
39-
- ./grafana/dashboards.yml:/etc/grafana/provisioning/dashboards/tweet_dashboards.yml
39+
- ./grafana/dashboards.yaml:/etc/grafana/provisioning/dashboards/tweet_dashboards.yaml
4040
- ./grafana/datasources:/etc/grafana/provisioning/datasources
4141
- ./grafana/alerting:/etc/grafana/alerting
42-
- ./grafana/alerting.yml:/etc/grafana/provisioning/alerting/tweet_alerts.yml
42+
- ./grafana/alertings.yaml:/etc/grafana/provisioning/alerting/tweet_alerts.yaml
4343
environment:
4444
- GF_SECURITY_ADMIN_PASSWORD=admin
4545
- GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH=/etc/grafana/dashboards/tweet_dashboard.json

grafana/alerting/rules.json

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
{
2+
"apiVersion": 1,
3+
"groups": [
4+
{
5+
"orgId": 1,
6+
"name": "job alert",
7+
"folder": "bert",
8+
"interval": "1m",
9+
"rules": [
10+
{
11+
"uid": "dej6cljqgxekge",
12+
"title": "Alert max jobs",
13+
"condition": "C",
14+
"data": [
15+
{
16+
"refId": "A",
17+
"relativeTimeRange": {
18+
"from": 600,
19+
"to": 0
20+
},
21+
"datasourceUid": "cej41rlo2r11cd",
22+
"model": {
23+
"disableTextWrap": false,
24+
"editorMode": "builder",
25+
"expr": "prediction_status{model=\"bert\"}",
26+
"fullMetaSearch": false,
27+
"includeNullMetadata": true,
28+
"instant": true,
29+
"intervalMs": 1000,
30+
"legendFormat": "__auto",
31+
"maxDataPoints": 43200,
32+
"range": false,
33+
"refId": "A",
34+
"useBackend": false
35+
}
36+
},
37+
{
38+
"refId": "C",
39+
"relativeTimeRange": {
40+
"from": 0,
41+
"to": 0
42+
},
43+
"datasourceUid": "__expr__",
44+
"model": {
45+
"conditions": [
46+
{
47+
"evaluator": {
48+
"params": [
49+
5
50+
],
51+
"type": "gt"
52+
},
53+
"operator": {
54+
"type": "and"
55+
},
56+
"query": {
57+
"params": [
58+
"C"
59+
]
60+
},
61+
"reducer": {
62+
"params": [],
63+
"type": "last"
64+
},
65+
"type": "query"
66+
}
67+
],
68+
"datasource": {
69+
"type": "__expr__",
70+
"uid": "__expr__"
71+
},
72+
"expression": "A",
73+
"intervalMs": 1000,
74+
"maxDataPoints": 43200,
75+
"refId": "C",
76+
"type": "threshold"
77+
}
78+
}
79+
],
80+
"noDataState": "NoData",
81+
"execErrState": "Error",
82+
"for": "1m",
83+
"annotations": {
84+
"description": "Task is up to 5 jobs",
85+
"summary": "Alert Sentiment analysis"
86+
},
87+
"isPaused": false,
88+
"notification_settings": {
89+
"receiver": "admin-email"
90+
}
91+
}
92+
]
93+
}
94+
]
95+
}

grafana/alerting/rules.yaml

Lines changed: 0 additions & 63 deletions
This file was deleted.

grafana/alertings.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
apiVersion: 1
1+
# apiVersion: 1
22

3-
groups:
4-
- name: alert-rules
5-
folder: Tweet Alerts
6-
orgId: 1
7-
interval: 30s
8-
rules:
9-
- file: /etc/grafana/alerting/rules.yml
3+
# groups:
4+
# - name: alert-rules
5+
# folder: Tweet Alerts
6+
# orgId: 1
7+
# interval: 30s
8+
# rules:
9+
# - file: /etc/grafana/alerting/rules.json

grafana/datasources/datasources.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ apiVersion: 1
33
datasources:
44
- name: Prometheus
55
type: prometheus
6+
uid: DS_TWEET_SENTIMENT SERVER METRICS
67
access: proxy
78
orgId: 1
89
url: http://prometheus:9090
@@ -11,6 +12,7 @@ datasources:
1112

1213
- name: Loki
1314
type: loki
15+
uid: DS_LOKI
1416
access: proxy
1517
orgId: 1
1618
url: http://loki:3100

mlruns/0/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
artifact_location: file:///workspace/sentimental_analyses/mlruns/0
1+
artifact_location: file:///Users/wonters/Desktop/openclassroom/projets/sentiment_analyses/sentimental_analyses/mlruns/0
22
creation_time: 1744984721760
33
experiment_id: '0'
44
last_update_time: 1744984721760

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,5 @@ scikit-optimize==0.10.2
101101
lightgbm==4.1.0
102102
pymongo==4.12.0
103103
prometheus-client==0.21.1
104+
optuna==3.4.0
104105

scripts/cloud-gpu.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#! /bin/bash
2+
sudo apt-get install gcc make -y
3+
sudo apt install build-essential linux-headers-$(uname -r)
4+
wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
5+
sudo sh ./cuda_11.8.0_520.61.05_linux.run --silent --driver --toolkit

src/mixins.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _train_batch(self, x, y):
6464
self.optimizer.zero_grad()
6565
inputs = inputs.to(self.device)
6666
labels = labels.to(self.device)
67+
# Give input_ids and attention masks
6768
outputs = self.model(**inputs)
6869
try:
6970
loss = self.criterion(outputs.logits, labels)

src/ml.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import joblib
33
import os
4+
import re
5+
import string
46
from functools import partial
57
from pathlib import Path
68
from abc import ABC
@@ -13,7 +15,7 @@
1315
from torch.utils.data import DataLoader
1416
import torch.nn.functional as F
1517
from sklearn.linear_model import LogisticRegression
16-
from lightgbm import LGBMClassifier
18+
import lightgbm as lgm
1719
from sklearn.model_selection import train_test_split, cross_val_score
1820
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
1921
from sklearn.metrics import confusion_matrix, classification_report
@@ -112,7 +114,6 @@ def init_mlflow(self, name:str = ""):
112114
self.run = mlflow.start_run(run_name=name if name else self.name)
113115
self.run_id = self.run.info.run_id
114116

115-
116117
def load_checkpoint(self) -> object:
117118
"""
118119
Logic to load the model from a checkpoint or create a new one
@@ -147,7 +148,7 @@ def confusion_matrix(self):
147148
plt.ylabel("Cluster réels")
148149
plt.savefig(f.name)
149150
plt.close()
150-
mlflow.log_artifact(f.name)#, "confusion_matrix.png")
151+
mlflow.log_artifact(f.name, "confusion_matrix.png")
151152

152153
def train(self):
153154
"""
@@ -170,6 +171,13 @@ def train(self):
170171
signature=signature,
171172
registered_model_name=f"{self.name}-quickstart"
172173
)
174+
elif isinstance(self.model, lgm.LGBMClassifier):
175+
mlflow.lightgbm.log_model(
176+
lgb_model=self.model,
177+
artifact_path=self.name,
178+
signature=signature,
179+
registered_model_name=f"{self.name}-quickstart",
180+
)
173181
else:
174182
mlflow.sklearn.log_model(
175183
sk_model=self.model,
@@ -189,7 +197,7 @@ def predict(self, x: Union[pd.Series, np.ndarray]):
189197
class SklearnBaseModel(BaseModelABC):
190198
def log_metrics(self):
191199
super().log_metrics()
192-
mlflow.sklearn.log_model(self.model, self.name)
200+
#mlflow.sklearn.log_model(self.model, self.name)
193201
mlflow.log_params(self.model.get_params())
194202

195203
class TorchBaseModel(TorchModelTrainMixin, BaseModelABC):
@@ -257,27 +265,34 @@ class LightGBMModel(SklearnBaseModel):
257265
def log_metrics(self):
258266
"""Callback pour logger les métriques dans MLflow"""
259267
super().log_metrics()
260-
mlflow.log_metric("oob_score", self.model.oob_score_)
268+
261269

262270
def init_items(self):
263-
self.model = LGBMClassifier(
264-
n_estimators=100, learning_rate=0.1, max_depth=6, random_state=42
271+
self.model = lgm.LGBMClassifier(
272+
#n_estimators=100, learning_rate=0.1, max_depth=6, random_state=42
265273
)
266274
self.tokenizer = self.tokenizer_class(
267-
max_features=1000, ngram_range=(1, 2), binary=True
275+
max_features=1000, min_df=2, max_df=0.95
268276
)
269277

278+
def clean(self, tweet):
279+
translator = str.maketrans('','', string.punctuation)
280+
tweet = tweet.translate(translator)
281+
tweet = re.sub("^[a-z][A-Z]", " ",tweet)
282+
tweet = tweet.lower()
283+
tweet = ' '.join(tweet.split())
284+
return tweet
285+
270286
def train(self):
271287
"""
272288
Train the Random Forest model with progress tracking
273289
"""
274290
self.init_mlflow()
275291
try:
276292
# Vectorisation du texte
277-
X_train = self.tokenizer.fit_transform(self.x_train)
293+
X_train = self.tokenizer.fit_transform(self.x_train.apply(self.clean))
278294
self.model.fit(X_train, self.y_train)
279295
super().train()
280-
# self.model.n_estimators += 10
281296
except Exception as e:
282297
logger.error(f"Erreur pendant l'entraînement: {str(e)}")
283298
raise
@@ -513,4 +528,5 @@ def load_data(path):
513528
df_tweets = pd.read_csv(path, names=headers, encoding="latin-1")
514529
# On prend target 0 negatif 1 positif
515530
df_tweets.loc[:, "target"] = df_tweets.target.map({0: int(0), 4: int(1)})
531+
516532
return df_tweets

0 commit comments

Comments
 (0)