Skip to content

Commit b7999e0

Browse files
committed
Merge branch 'develop'
2 parents fdfb619 + f5621d0 commit b7999e0

4 files changed

Lines changed: 209 additions & 21 deletions

File tree

.idea/sentimental_analyses.iml

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
# Sentimental analyses with MLFLOW and models Wrappers
22

3+
[![version](https://img.shields.io/badge/version-1.0.0-green.svg)](https://semver.org)
4+
5+
## Table of content
6+
- [Overview](#overview)
7+
- [Architecture](#architecture)
8+
- [Install](#install)
9+
- [Usage](#usage)
10+
- [Contributing](#contributing)
11+
- [Production](#production)
12+
- [Monitoring](#monitoring)
13+
- [Api](#api)
14+
- [License](#license)
15+
- [Author](#author)
16+
- [Thanks](#thanks)
17+
18+
## Overview
19+
320
Tweet sentimental analyses with different models.
421

522
Four wrapper of models:
@@ -15,18 +32,7 @@ MLFlow is used to list all experiments and easily commpare results for several d
1532
Optuna is used to optimise parameters. It run a set of experiments with a variation of parameters and select the best configuration
1633
maximising the accuracy.
1734

18-
19-
20-
The app is dockerised and can be installed launching the command
21-
```bash
22-
docker compose up
23-
```
24-
or to run in background
25-
```bash
26-
docker compose up -d
27-
```
28-
29-
## Access and architecture
35+
## Architecture
3036
The application contains alerting system and monitoring on grafana on port 3000
3137
APP PORT
3238
MLFLOW 5001
@@ -42,15 +48,36 @@ Prometheus send metrics as the number of prediction running.
4248
An alert is send by mail when number of predictions in concurrency are up to 5.
4349
An alert is send when the result of the prediction is too bad, probability < 0.5.
4450

45-
## Installation in dev
46-
# Install uv (Rust package to fastly install package)
51+
52+
## Install
53+
54+
The app is dockerised and can be installed launching the command
55+
```bash
56+
docker compose up
57+
```
58+
or to run in background
59+
```bash
60+
docker compose up -d
61+
```
62+
63+
64+
## Contributing
65+
#### Install uv (Rust package to fastly install package)
4766
```bash
4867
curl -Ls https://astral.sh/uv/install.sh | bash
4968
export PATH="$HOME/.cargo/bin:$PATH"
5069
```
70+
#### Source the code in the container
71+
Modify the docker-compose.yaml to add the source code as volume
72+
```bash
73+
volumes:
74+
- ./src:/app/src/
75+
- ./mlruns:/app/mlruns/
76+
```
5177

52-
## OVH Train with AI train
78+
## Usage
5379

80+
### OVH Train with AI train
5481
Create an object storage on OVH managed with ovhai cli
5582
The secret key is obtain clicking on the user object storage line 'access secret key'
5683
```bash
@@ -65,7 +92,7 @@ Credentials are stored in ~/.config/ovhai/context.json
6592
uv pip install boto3 awscli ovhai
6693
```
6794

68-
## Run on multi GPU
95+
### Run on multi GPU
6996
DEBUG
7097
```bash
7198
export TORCH_DISTRIBUTED_DEBUG=DETAIL
@@ -74,16 +101,36 @@ export TORCH_DISTRIBUTED_DEBUG=DETAIL
74101
python -m torch.distributed.run --nproc_per_node=2 train.py
75102
```
76103

77-
## Tests
104+
### Tests
78105
```bash
79106
pytest src/tests
80107
```
81108

82-
## Launch a test to verify the prection from the API
109+
### Launch a test to verify the prection from the API
83110
Go on 127.0.0.1:5000, tap your tweet and click on predict button
84111

85-
Par requête http
112+
113+
## Production
114+
An exemple deployment is available on https://tweetsentiment.shift.python.software.fr
115+
116+
## Monitoring
117+
Add alert and monitoring and dashboard on grafana on your local instance
118+
and save them in grafana folder.
119+
Reload grafana and they will be available on http://localhost:3000 as provisionning templates
120+
121+
## Api
122+
You can contact the api example
123+
or change the url on the script predict_client.py to test your instance
86124
```bash
87125
export $(cat .env | xargs)
88-
python post.py
126+
python predict_client.py
89127
```
128+
## License
129+
130+
MIT License
131+
132+
## Author
133+
Shift python software
134+
135+
## Thanks
136+
Thanks to all contributors

src/tests/test_ml.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#import lightgbm
2+
import pytest
3+
import time
4+
from fastapi.testclient import TestClient
5+
from ..ml import (
6+
LogisticRegressionModel,
7+
load_data,
8+
BertModel,
9+
RobertaModel,
10+
LSTMModel,
11+
RandomForestModel,
12+
LightGBMModel,
13+
)
14+
from transformers import PreTrainedModel
15+
from ..server import app
16+
17+
class BaseTest:
18+
file = "data/tweets_test_train.csv"
19+
class_model = None
20+
21+
@classmethod
22+
def setup_class(cls):
23+
df = load_data(cls.file)
24+
cls.model = cls.class_model(dataset=df)
25+
26+
def test_train(self):
27+
self.model.train()
28+
29+
def test_tokenizer(self):
30+
self.model.tokenizer.transform(self.model.x_train)
31+
32+
def test_preprocessing(self):
33+
self.model.preprocessing(self.model.x_train)
34+
35+
36+
class TestLogisticRegressionModel(BaseTest):
37+
class_model = LogisticRegressionModel
38+
39+
def test_predict(self):
40+
result = self.model.predict(list(self.model.x_test))
41+
print(result, self.model.y_test.values)
42+
assert result.tolist() == [0, 1, 0, 0, 0, 0]
43+
44+
45+
class TestLightGBMModel(BaseTest):
46+
class_model = LightGBMModel
47+
48+
def test_train(self):
49+
self.model.train()
50+
51+
52+
class TestBertModel(BaseTest):
53+
class_model = BertModel
54+
55+
def test_predict(self):
56+
result = self.model.predict(list(self.model.x_test))
57+
assert [r['prediction'] for r in result] == [1, 1, 0, 0, 0, 0]
58+
59+
def test_confusion_matrix(self):
60+
self.model.confusion_matrix()
61+
62+
def test_optuna_train(self):
63+
self.model.optuna_train(n_trials=5)
64+
65+
class TestRobertaModel(BaseTest):
66+
class_model = RobertaModel
67+
68+
def test_optuna_train(self):
69+
self.model.optuna_train(n_trials=5)
70+
71+
def test_predict(self):
72+
result = self.model.predict(list(self.model.x_test))
73+
print(result, self.model.y_test.values)
74+
assert [r['prediction'] for r in result] == [0, 0, 1, 0, 0, 0]
75+
76+
77+
class TestLSTMModel(BaseTest):
78+
class_model = LSTMModel
79+
80+
def test_size_vocab(self):
81+
print(self.model.tokenizer.vocab_size)
82+
83+
84+
def test_predict(self):
85+
result = self.model.predict(list(self.model.x_test))
86+
assert result.tolist() == [1, 0, 0, 0, 0, 0]
87+
88+
89+
class TestRandomForestModel(BaseTest):
90+
class_model = RandomForestModel
91+
92+
93+
class TestServer:
94+
95+
@classmethod
96+
def setup_class(cls):
97+
cls.client = TestClient(app)
98+
99+
def test_main(self):
100+
rep = self.client.get("/")
101+
assert rep.status_code == 200
102+
103+
def test_predict(self):
104+
response = self.client.post("/predict", json=[{"text": "hello world"}])
105+
assert response.status_code == 200
106+
payload = response.json()
107+
assert payload["status"] == "processing"
108+
task_id = payload["task_id"]
109+
response = self.client.get(f"/get_result/{task_id}")
110+
payload = response.json()
111+
while payload["status"] == "processing":
112+
response = self.client.get(f"/get_result/{task_id}")
113+
time.sleep(1)
114+
payload = response.json()
115+
print(payload)
116+
# assert response.json() == {}

templates/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ <h2 class="text-justify m-2"> Bienvenue Sur TweetSentimentPredict </h2>
4747
body: JSON.stringify(payload)
4848
}).then(res => res.json()).then(res => {
4949
button.disabled = true;
50-
ws = new WebSocket("ws://127.0.0.1:5000/ws/" + res.task_id)
50+
ws = new WebSocket(`ws://${window.location.host}/ws/` + res.task_id)
5151
ws.onmessage = (event) => {
5252
const element = document.getElementById("result")
5353
const payload = JSON.parse(event.data)

0 commit comments

Comments
 (0)