|
1 | 1 | import unittest |
2 | 2 | import sys |
| 3 | +import os |
3 | 4 | sys.path.append("/usr/src/app/model") |
4 | 5 | from unittest import mock |
5 | 6 | from unittest.mock import patch |
6 | | -from main import main |
| 7 | +import main |
7 | 8 | from nets import MobileNetV2 |
8 | 9 | import torch |
9 | 10 |
|
10 | 11 | class TestMainFunction(unittest.TestCase): |
11 | 12 |
|
| 13 | + def __init__(self): |
| 14 | + self.input_dir = "input_raw" |
| 15 | + self.resized_image_path = os.path.join(self.input_dir, "image.jpg") |
| 16 | + self.output_dir = "output_raw" |
| 17 | + os.makedirs(self.output_dir, exist_ok=True) # Create output_raw if it doesn't exist |
| 18 | + self.output_file_path = os.path.join(self.output_dir, "output_prediction.txt") |
| 19 | + |
12 | 20 | def test_missing_input_image(self): |
13 | 21 | # Test with a missing input image |
14 | 22 | with self.assertRaises(FileNotFoundError): |
15 | | - main().preprocess_image('/nonexistent/image.jpg') |
| 23 | + main.preprocess_image('/nonexistent/image.jpg') |
16 | 24 |
|
17 | 25 | def test_tensor_size(self): |
18 | | - image, img_shape, img_type = main().preprocess_image('/input/resized_image.jpg') |
| 26 | + image, img_shape, img_type = main.preprocess_image(self.resized_image_path) |
19 | 27 | print(img_shape, img_type) |
20 | | - self.assertNotEquals(img_shape, (1,2,3), 'Not Equal Shape') |
| 28 | + self.assertNotEqual(img_shape, (1,2,3), 'Equal Shape') |
21 | 29 |
|
22 | 30 | def test_image_type(self): |
23 | | - image, img_shape, img_type = main().preprocess_image('/input/resized_image.jpg') |
| 31 | + image, img_shape, img_type = main.preprocess_image(self.resized_image_path) |
24 | 32 | print(img_shape, img_type) |
25 | | - self.assertNotEquals(img_type, int, 'Not Equal Image Type') |
| 33 | + self.assertNotEqual(img_type, int, 'Equal Image Type') |
26 | 34 |
|
27 | 35 | def test_image_class(self): |
28 | 36 | model = MobileNetV2() |
29 | | - model.load_state_dict(torch.load("./model/weights/mobilenetv2.pt", weights_only=True)) # weights ported from torchvision |
| 37 | + model.load_state_dict(torch.load("/Users/aibekakhmetkazy/PycharmProjects/fse4ai_team_2/mobilenetv2-pytorch/weights/mobilenetv2.pt", weights_only=True)) # weights ported from torchvision |
30 | 38 | model.float() |
31 | 39 |
|
32 | | - predicted_label = main().inference(model, '/input/resized_image.jpg', |
33 | | - '/output_raw/output_prediction.txt') |
| 40 | + predicted_label = main.inference(model, self.resized_image_path, |
| 41 | + self.output_file_path) |
34 | 42 |
|
35 | 43 | print('Label in output:',predicted_label) |
36 | | - self.assertNotEquals(predicted_label, 'Aar', 'Wrong classification') |
| 44 | + self.assertNotEqual(predicted_label, 'House', 'Correct classification') |
37 | 45 |
|
38 | 46 | if __name__ == '__main__': |
39 | 47 | unittest.main() |
0 commit comments