-
Notifications
You must be signed in to change notification settings - Fork 8k
/
Copy pathtest_predict.py
35 lines (25 loc) · 973 Bytes
/
test_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math
from regression_model.predict import make_prediction
from regression_model.processing.data_management import load_dataset
def test_make_single_prediction():
    """Check that a one-row input yields a float prediction with the pinned value.

    The ``[0:1]`` slice of the dataset loaded from ``test.csv`` is passed to
    ``make_prediction``; the first entry under the ``'predictions'`` key must
    be a ``float`` whose ceiling equals 112512.
    """
    # Given: only the leading slice of the test dataset
    first_row = load_dataset(file_name='test.csv')[0:1]

    # When
    result = make_prediction(input_data=first_row)

    # Then
    assert result is not None
    # Looked up only after the None check so a missing result fails on the
    # assertion above rather than on the attribute access.
    prediction = result.get('predictions')[0]
    assert isinstance(prediction, float)
    assert math.ceil(prediction) == 112512
def test_make_multiple_predictions():
    """Check the number of predictions returned for the whole test dataset.

    The full dataset loaded from ``test.csv`` is passed to
    ``make_prediction``; the list under the ``'predictions'`` key must hold
    exactly 1451 entries and must differ in length from the input.
    """
    # Given: the complete test dataset, with its length recorded before
    # the prediction call
    full_input = load_dataset(file_name='test.csv')
    input_row_count = len(full_input)

    # When
    result = make_prediction(input_data=full_input)

    # Then
    assert result is not None
    prediction_count = len(result.get('predictions'))
    assert prediction_count == 1451
    # We expect some rows to be filtered out
    assert prediction_count != input_row_count