-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_example_predictions.py
More file actions
executable file
·71 lines (57 loc) · 2 KB
/
plot_example_predictions.py
File metadata and controls
executable file
·71 lines (57 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python3
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
import os
# import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from torchgeo.datamodules import L7IrishDataModule
from torchgeo.datasets import unbind_samples
device = torch.device('cpu')

# Load weights from the training checkpoint.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. On newer PyTorch releases weights_only defaults to True; confirm
# this checkpoint loads under that setting.
path = 'data/l7irish/checkpoint-epoch=26-val_loss=0.68.ckpt'
checkpoint = torch.load(path, map_location=device)

# Parameter names in the checkpoint carry a leading 'model.' that the bare
# smp.Unet below does not have. Strip only that leading prefix; str.replace
# would also mangle any 'model.' occurring later in a key.
prefix = 'model.'
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint['state_dict'].items()
}

# Initialize model
model = smp.Unet(encoder_name='resnet18', in_channels=9, classes=5)
model.to(device)
model.load_state_dict(state_dict)
# Inference only: eval mode disables training-time behaviour such as
# batch-norm using (and updating) per-batch statistics.
model.eval()
# Build the L7 Irish data module rooted at data/l7irish (download=True lets
# torchgeo fetch the data if it is missing), then prepare only the test split.
datamodule = L7IrishDataModule(
    root='data/l7irish',
    crs='epsg:3857',
    download=True,
    batch_size=1,
    patch_size=224,
)
datamodule.setup('test')
# Run the model over the test split and write image / mask / prediction PNGs
# for a selection of "interesting" patches to data/l7irish_predictions/<i>/.
i = 0
for batch in datamodule.test_dataloader():
    mask = batch['mask']

    # Skip patches containing any mask value 0 (treated as nodata here)
    if 0 in mask:
        continue
    # Skip boring patches with fewer than 4 distinct mask values
    if len(mask.unique()) < 4:
        continue

    # Tensor.to() is not in-place: the result must be rebound, otherwise the
    # model would silently receive the original tensor on non-CPU devices.
    image = batch['image'].to(device)

    # NOTE(review): batches are taken straight from test_dataloader(), so any
    # normalization/augmentation the datamodule applies after batch transfer
    # during training is skipped here. Confirm the model expects these values.
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(image)

    # Per-pixel class index, moved back to CPU so .numpy() works when saving.
    batch['prediction'] = logits.argmax(dim=1).cpu()
    # Highest class index (dim 1 is the class dim, as used by argmax above);
    # used to stretch label values onto 0..255 for visibility.
    max_label = logits.shape[1] - 1

    for sample in unbind_samples(batch):
        # For interactive inspection, datamodule.test_dataset.plot(sample)
        # can be shown with matplotlib instead of saving PNGs.
        out_dir = f'data/l7irish_predictions/{i}'
        print(f'Saving {out_dir}...')
        os.makedirs(out_dir, exist_ok=True)
        for key in ['image', 'mask', 'prediction']:
            data = sample[key]
            if key == 'image':
                # Channels 2, 1, 0 are written as R, G, B; values are cast
                # directly to uint8 (assumes they already lie in 0..255 —
                # TODO confirm).
                rgb = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype('uint8')
                Image.fromarray(rgb, 'RGB').save(f'{out_dir}/{key}.png')
            else:
                gray = (data * 255 / max_label).numpy().astype('uint8').squeeze()
                Image.fromarray(gray, 'L').save(f'{out_dir}/{key}.png')
        i += 1