Skip to content

Commit 138e855

Browse files
refactor final
1 parent d9ff88e commit 138e855

File tree

3 files changed

+47
-42
lines changed

3 files changed

+47
-42
lines changed

Dockerfile

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@ RUN pip3 install -r requirements.txt
3131
# Set the working directory
3232
WORKDIR /usr/src/app
3333

34-
# Run the tests to ensure everything is working correctly
35-
#RUN make test
36-
3734
# Set the entrypoint to run the pipeline via Makefile.run
3835
ENTRYPOINT ["make", "-f", "Makefile.run"]
3936

main.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,39 +7,39 @@
77
from nets import MobileNetV2
88
from assets.meta import IMAGENET_CATEGORIES
99

10+
11+
# Assume the image is already resized; no need for transforms.Resize
12+
def preprocess_image(resized_image_path):
13+
transform = transforms.Compose([
14+
transforms.ToTensor(), # Convert the image to a PyTorch tensor
15+
transforms.Normalize(
16+
mean=[0.485, 0.456, 0.406], # Normalize using the mean and std of ImageNet
17+
std=[0.229, 0.224, 0.225]
18+
),
19+
])
20+
21+
image = Image.open(resized_image_path)
22+
image = transform(image).unsqueeze(0) # Add a batch dimension
23+
return image, image.shape, type(image)
24+
25+
def inference(model, resized_image_path, output_file_path):
26+
model.eval()
27+
28+
input_image, image_shape, image_type = preprocess_image(resized_image_path)
29+
with torch.no_grad():
30+
output = model(input_image)
31+
32+
_, predicted_class = output.max(1)
33+
predicted_label = IMAGENET_CATEGORIES[predicted_class.item()]
34+
# Write output in text file
35+
with open(output_file_path, "w") as f:
36+
f.write(predicted_label)
37+
return predicted_label
38+
1039
def main():
1140
model = MobileNetV2()
1241
model.load_state_dict(torch.load("./model/weights/mobilenetv2.pt", weights_only=True)) # weights ported from torchvision
1342
model.float() # converting weights to float32
14-
15-
# Assume the image is already resized; no need for transforms.Resize
16-
def preprocess_image(resized_image_path):
17-
transform = transforms.Compose([
18-
transforms.ToTensor(), # Convert the image to a PyTorch tensor
19-
transforms.Normalize(
20-
mean=[0.485, 0.456, 0.406], # Normalize using the mean and std of ImageNet
21-
std=[0.229, 0.224, 0.225]
22-
),
23-
])
24-
25-
image = Image.open(resized_image_path)
26-
image = transform(image).unsqueeze(0) # Add a batch dimension
27-
return image, image.shape, type(image)
28-
29-
def inference(model, resized_image_path, output_file_path):
30-
model.eval()
31-
32-
input_image, image_shape, image_type = preprocess_image(resized_image_path)
33-
with torch.no_grad():
34-
output = model(input_image)
35-
36-
_, predicted_class = output.max(1)
37-
predicted_label = IMAGENET_CATEGORIES[predicted_class.item()]
38-
# Write output in text file
39-
with open(output_file_path, "w") as f:
40-
f.write(predicted_label)
41-
return predicted_label
42-
4343
# Input directory
4444
input_dir = "input"
4545
resized_image_path = os.path.join(input_dir, "resized_image.jpg")

test_main.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,47 @@
11
import unittest
22
import sys
3+
import os
34
sys.path.append("/usr/src/app/model")
45
from unittest import mock
56
from unittest.mock import patch
6-
from main import main
7+
import main
78
from nets import MobileNetV2
89
import torch
910

1011
class TestMainFunction(unittest.TestCase):
1112

13+
def __init__(self):
14+
self.input_dir = "input_raw"
15+
self.resized_image_path = os.path.join(self.input_dir, "image.jpg")
16+
self.output_dir = "output_raw"
17+
os.makedirs(self.output_dir, exist_ok=True) # Create output_raw if it doesn't exist
18+
self.output_file_path = os.path.join(self.output_dir, "output_prediction.txt")
19+
1220
def test_missing_input_image(self):
1321
# Test with a missing input image
1422
with self.assertRaises(FileNotFoundError):
15-
main().preprocess_image('/nonexistent/image.jpg')
23+
main.preprocess_image('/nonexistent/image.jpg')
1624

1725
def test_tensor_size(self):
18-
image, img_shape, img_type = main().preprocess_image('/input/resized_image.jpg')
26+
image, img_shape, img_type = main.preprocess_image(self.resized_image_path)
1927
print(img_shape, img_type)
20-
self.assertNotEquals(img_shape, (1,2,3), 'Not Equal Shape')
28+
self.assertNotEqual(img_shape, (1,2,3), 'Equal Shape')
2129

2230
def test_image_type(self):
23-
image, img_shape, img_type = main().preprocess_image('/input/resized_image.jpg')
31+
image, img_shape, img_type = main.preprocess_image(self.resized_image_path)
2432
print(img_shape, img_type)
25-
self.assertNotEquals(img_type, int, 'Not Equal Image Type')
33+
self.assertNotEqual(img_type, int, 'Equal Image Type')
2634

2735
def test_image_class(self):
2836
model = MobileNetV2()
29-
model.load_state_dict(torch.load("./model/weights/mobilenetv2.pt", weights_only=True)) # weights ported from torchvision
37+
model.load_state_dict(torch.load("/Users/aibekakhmetkazy/PycharmProjects/fse4ai_team_2/mobilenetv2-pytorch/weights/mobilenetv2.pt", weights_only=True)) # weights ported from torchvision
3038
model.float()
3139

32-
predicted_label = main().inference(model, '/input/resized_image.jpg',
33-
'/output_raw/output_prediction.txt')
40+
predicted_label = main.inference(model, self.resized_image_path,
41+
self.output_file_path)
3442

3543
print('Label in output:',predicted_label)
36-
self.assertNotEquals(predicted_label, 'Aar', 'Wrong classification')
44+
self.assertNotEqual(predicted_label, 'House', 'Correct classification')
3745

3846
if __name__ == '__main__':
3947
unittest.main()

0 commit comments

Comments
 (0)