-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
61 lines (51 loc) · 1.78 KB
/
Copy patheval.py
File metadata and controls
61 lines (51 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from numpy.core.fromnumeric import shape
import torch
from model import UNET
from utils import *
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import argparse
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Size handed to transforms.Resize below. With an int size, Resize scales the
# shorter image edge to this value and keeps the aspect ratio, so the result
# is not necessarily IMSIZE x IMSIZE.
IMSIZE = 128
# Preprocessing applied to the input image before it is fed to the model.
# transforms.Scale was deprecated and later removed from torchvision (it now
# raises AttributeError); transforms.Resize is its drop-in replacement with
# the same int-size semantics.
loader = transforms.Compose([transforms.Resize(IMSIZE), transforms.ToTensor()])
def _build_arg_parser():
    """Return the command-line parser for this lane-segmentation test script."""
    parser = argparse.ArgumentParser(description="Test for lane segmentation")
    parser.add_argument(
        "--image_path",
        type=str,
        default="",
        help='Path to the image that needs to be tested. (default: "")',
    )
    parser.add_argument(
        "--save_img",
        type=str,
        default="",
        help=(
            "Path where the image with the predicted mask overlay is saved; "
            'nothing is saved when empty. (default: "")'
        ),
    )
    return parser


def main():
    """Segment one image with the trained UNET and show the masked result.

    Reads ``--image_path``, runs the model loaded from
    ``my_checkpoint.pth.tar``, thresholds the sigmoid output at 0.5, overlays
    the mask with ``apply_mask``, optionally saves the overlay to
    ``--save_img``, and displays it with matplotlib.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.image_path == "":
        # Raising a plain string is a TypeError in Python 3; report the usage
        # problem through argparse (prints usage and exits with status 2).
        parser.error("no image path specified")

    # Force 3-channel RGB so RGBA / palette / grayscale files match
    # UNET(in_channels=3); `with` closes the underlying file handle.
    with Image.open(args.image_path) as img_file:
        image_real = img_file.convert("RGB")

    # Preprocess, add a leading batch dimension, and move to the model device.
    image = loader(image_real).float().unsqueeze(0).to(DEVICE)

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)
    model.eval()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(image)
    # Drop the channel axis (presumably output is (N, 1, H, W) -- confirm).
    probabilities = torch.sigmoid(output.squeeze(1))
    # Binary {0.0, 1.0} mask; .cpu() is a no-op for tensors already on the CPU.
    predicted_mask = (probabilities >= 0.5).float().cpu().numpy()

    image_with_mask = apply_mask(image_real=image_real, predicted_mask=predicted_mask)
    image_with_mask = Image.fromarray(image_with_mask.astype(np.uint8)).convert("RGB")
    if args.save_img:
        image_with_mask.save(args.save_img)
    plt.imshow(image_with_mask)
    plt.show()


if __name__ == "__main__":
    main()