Skip to content

Commit 9a89288

Browse files
committed
add tests
1 parent 5d9afc1 commit 9a89288

1 file changed

Lines changed: 116 additions & 0 deletions

File tree

src/tests/test_ml.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#import lightgbm
2+
import pytest
3+
import time
4+
from fastapi.testclient import TestClient
5+
from ..ml import (
6+
LogisticRegressionModel,
7+
load_data,
8+
BertModel,
9+
RobertaModel,
10+
LSTMModel,
11+
RandomForestModel,
12+
LightGBMModel,
13+
)
14+
from transformers import PreTrainedModel
15+
from ..server import app
16+
17+
class BaseTest:
18+
file = "data/tweets_test_train.csv"
19+
class_model = None
20+
21+
@classmethod
22+
def setup_class(cls):
23+
df = load_data(cls.file)
24+
cls.model = cls.class_model(dataset=df)
25+
26+
def test_train(self):
27+
self.model.train()
28+
29+
def test_tokenizer(self):
30+
self.model.tokenizer.transform(self.model.x_train)
31+
32+
def test_preprocessing(self):
33+
self.model.preprocessing(self.model.x_train)
34+
35+
36+
class TestLogisticRegressionModel(BaseTest):
37+
class_model = LogisticRegressionModel
38+
39+
def test_predict(self):
40+
result = self.model.predict(list(self.model.x_test))
41+
print(result, self.model.y_test.values)
42+
assert result.tolist() == [0, 1, 0, 0, 0, 0]
43+
44+
45+
class TestLightGBMModel(BaseTest):
46+
class_model = LightGBMModel
47+
48+
def test_train(self):
49+
self.model.train()
50+
51+
52+
class TestBertModel(BaseTest):
53+
class_model = BertModel
54+
55+
def test_predict(self):
56+
result = self.model.predict(list(self.model.x_test))
57+
assert [r['prediction'] for r in result] == [1, 1, 0, 0, 0, 0]
58+
59+
def test_confusion_matrix(self):
60+
self.model.confusion_matrix()
61+
62+
def test_optuna_train(self):
63+
self.model.optuna_train(n_trials=5)
64+
65+
class TestRobertaModel(BaseTest):
66+
class_model = RobertaModel
67+
68+
def test_optuna_train(self):
69+
self.model.optuna_train(n_trials=5)
70+
71+
def test_predict(self):
72+
result = self.model.predict(list(self.model.x_test))
73+
print(result, self.model.y_test.values)
74+
assert [r['prediction'] for r in result] == [0, 0, 1, 0, 0, 0]
75+
76+
77+
class TestLSTMModel(BaseTest):
78+
class_model = LSTMModel
79+
80+
def test_size_vocab(self):
81+
print(self.model.tokenizer.vocab_size)
82+
83+
84+
def test_predict(self):
85+
result = self.model.predict(list(self.model.x_test))
86+
assert result.tolist() == [1, 0, 0, 0, 0, 0]
87+
88+
89+
class TestRandomForestModel(BaseTest):
90+
class_model = RandomForestModel
91+
92+
93+
class TestServer:
94+
95+
@classmethod
96+
def setup_class(cls):
97+
cls.client = TestClient(app)
98+
99+
def test_main(self):
100+
rep = self.client.get("/")
101+
assert rep.status_code == 200
102+
103+
def test_predict(self):
104+
response = self.client.post("/predict", json=[{"text": "hello world"}])
105+
assert response.status_code == 200
106+
payload = response.json()
107+
assert payload["status"] == "processing"
108+
task_id = payload["task_id"]
109+
response = self.client.get(f"/get_result/{task_id}")
110+
payload = response.json()
111+
while payload["status"] == "processing":
112+
response = self.client.get(f"/get_result/{task_id}")
113+
time.sleep(1)
114+
payload = response.json()
115+
print(payload)
116+
# assert response.json() == {}

0 commit comments

Comments
 (0)