-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestprediction.py
More file actions
101 lines (67 loc) · 2.65 KB
/
Copy pathtestprediction.py
File metadata and controls
101 lines (67 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import loader
import torch
import matplotlib.pyplot as plt
import random
def show_random(wandb_path, index=None):
    """Plot one test sample beside its ground-truth mask and the model prediction.

    The test dataset and the model are both obtained from the project-local
    ``loader`` module using the same ``wandb_path``. The model is moved to
    CUDA when available (otherwise CPU), put in eval mode, and run under
    ``torch.no_grad()`` on a single-sample batch.

    Args:
        wandb_path: Passed unchanged to ``loader.get_test_dataset_wandb`` and
            ``loader.load_model`` (presumably a W&B run path - TODO confirm).
        index: Optional dataset index to display. When ``None`` (the default)
            an index is drawn uniformly at random with ``random.randint``,
            exactly as before this parameter existed.

    Raises:
        ValueError: If the test dataset is empty.
        IndexError: If ``index`` is given and is outside ``[0, len(dataset))``.
    """
    test_dataset, _ = loader.get_test_dataset_wandb(wandb_path)
    n_samples = len(test_dataset)
    # Fail fast with a clear message instead of randint's "empty range" error,
    # and before paying for a model load.
    if n_samples == 0:
        raise ValueError(f"Test dataset for {wandb_path!r} is empty; nothing to show.")
    if index is not None and not 0 <= index < n_samples:
        raise IndexError(f"index {index} out of range for dataset of size {n_samples}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = loader.load_model(wandb_path)
    model.to(device)
    model.eval()

    if index is None:
        index = random.randint(0, n_samples - 1)
    image, ground_truth = test_dataset[index]

    with torch.no_grad():
        # unsqueeze(0) adds a leading batch dimension of size 1.
        prediction = model(image.unsqueeze(0).to(device))

    # NOTE(review): permute(1, 2, 0) assumes the image is channel-first
    # (C, H, W); matplotlib expects (H, W, C) - confirm against the dataset.
    img_np = image.permute(1, 2, 0).cpu().numpy()
    mask_np = ground_truth.squeeze().cpu().numpy()
    # A single squeeze drops the batch dim (and any other size-1 dims).
    pred_np = prediction.squeeze().cpu().numpy()

    _, ax = plt.subplots(1, 3, figsize=(15, 5))
    # (array, title, colormap); cmap=None is imshow's default colormap.
    panels = (
        (img_np, f"Original Image (Index: {index})", None),
        (mask_np, "Ground Truth", "gray"),
        (pred_np, "Model Prediction", "gray"),
    )
    for axis, (data, title, cmap) in zip(ax, panels):
        axis.imshow(data, cmap=cmap)
        axis.set_title(title)
        axis.axis('off')
    plt.tight_layout()
    plt.show()
def show_random_double(wandb_path_1, wandb_path_2, index=None,
                       labels=("UNet Prediction", "VNet Prediction")):
    """Plot one test sample, its ground-truth mask, and two models' predictions.

    The test dataset is loaded from ``wandb_path_1`` only. One model is loaded
    from each path, and both are run on the same sample. Both models are moved
    to CUDA when available (otherwise CPU), put in eval mode, and run under
    ``torch.no_grad()``.

    NOTE(review): this assumes both runs share the same test split and
    preprocessing, because run 2's dataset is never consulted. Please confirm.

    Args:
        wandb_path_1: Source of the test dataset and of the first model.
        wandb_path_2: Source of the second model.
        index: Optional dataset index to display. When ``None`` (the default)
            an index is drawn uniformly at random with ``random.randint``.
        labels: Titles for the two prediction panels, in path order. The
            default reproduces the original hard-coded titles.

    Raises:
        ValueError: If the test dataset is empty or ``labels`` does not hold
            exactly two titles.
        IndexError: If ``index`` is given and is outside ``[0, len(dataset))``.
    """
    labels = tuple(labels)
    if len(labels) != 2:
        raise ValueError(f"labels must contain exactly 2 titles, got {len(labels)}")

    test_dataset, _ = loader.get_test_dataset_wandb(wandb_path_1)
    n_samples = len(test_dataset)
    # Fail fast with a clear message before loading two models.
    if n_samples == 0:
        raise ValueError(f"Test dataset for {wandb_path_1!r} is empty; nothing to show.")
    if index is not None and not 0 <= index < n_samples:
        raise IndexError(f"index {index} out of range for dataset of size {n_samples}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load both models first, then move/eval each, matching the original order.
    models = [loader.load_model(path) for path in (wandb_path_1, wandb_path_2)]
    for model in models:
        model.to(device)
        model.eval()

    if index is None:
        index = random.randint(0, n_samples - 1)
    image, ground_truth = test_dataset[index]

    # unsqueeze(0) adds a leading batch dimension of size 1.
    batch = image.unsqueeze(0).to(device)
    with torch.no_grad():
        # A single squeeze drops the batch dim (and any other size-1 dims).
        pred_arrays = [model(batch).squeeze().cpu().numpy() for model in models]

    # NOTE(review): permute(1, 2, 0) assumes the image is channel-first
    # (C, H, W); matplotlib expects (H, W, C) - confirm against the dataset.
    img_np = image.permute(1, 2, 0).cpu().numpy()
    mask_np = ground_truth.squeeze().cpu().numpy()

    _, ax = plt.subplots(1, 4, figsize=(15, 5))
    # (array, title, colormap); cmap=None is imshow's default colormap.
    panels = [
        (img_np, f"Original Image (Index: {index})", None),
        (mask_np, "Ground Truth", "gray"),
    ]
    panels.extend((pred, title, "gray") for pred, title in zip(pred_arrays, labels))
    for axis, (data, title, cmap) in zip(ax, panels):
        axis.imshow(data, cmap=cmap)
        axis.set_title(title)
        axis.axis('off')
    plt.tight_layout()
    plt.show()