-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAOD_dehaze.py
More file actions
executable file
·115 lines (99 loc) · 4.28 KB
/
Copy pathAOD_dehaze.py
File metadata and controls
executable file
·115 lines (99 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# The Reproduction of the AOD-Net End-to-End Dehazing Network
# Rain CongJyu CHEN
# Refs: https://github.com/MayankSingal/PyTorch-Image-Dehazing
# AOD-Net Dehazing Program
import glob
import os

import numpy as np
import torch
import torchvision
from PIL import Image
# Set the training device for AOD-Net.
def select_training_device():
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU.

    ``torch.backends.mps.is_available()`` is used because it is the
    long-documented MPS probe; ``torch.mps.is_available`` is not present in
    every PyTorch release, and calling it where it is missing raises
    ``AttributeError`` at import time even on CUDA/CPU machines.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Device used for both the network and its inputs throughout the script.
training_device = select_training_device()
# AOD-Net model.
class AODDehazeNet(torch.nn.Module):
    """AOD-Net end-to-end dehazing network.

    Five small convolutions estimate a per-pixel map ``K(x)``. The clean
    image is recovered as ``J(x) = K(x) * I(x) - K(x) + 1``, which is then
    clipped at zero by a ReLU.

    The submodule names ``e_conv1`` .. ``e_conv5`` determine the state-dict
    keys that a pretrained checkpoint is matched against, so they must not be
    renamed.
    """

    def __init__(self, *args, **kwargs):
        # Initialise the torch.nn.Module base exactly once, before any
        # submodule is registered.
        super().__init__(*args, **kwargs)
        self.relu = torch.nn.ReLU(inplace=True)
        # Conv2d(in_channels, out_channels, kernel, stride, padding): each
        # padding equals kernel // 2, so spatial size is preserved.
        self.e_conv1 = torch.nn.Conv2d(3, 3, 1, 1, 0, bias=True)
        self.e_conv2 = torch.nn.Conv2d(3, 3, 3, 1, 1, bias=True)
        self.e_conv3 = torch.nn.Conv2d(6, 3, 5, 1, 2, bias=True)
        self.e_conv4 = torch.nn.Conv2d(6, 3, 7, 1, 3, bias=True)
        self.e_conv5 = torch.nn.Conv2d(12, 3, 3, 1, 1, bias=True)

    def forward(self, x):
        """Dehaze a batch of images.

        Args:
            x: hazy images with 3 channels, shape (N, 3, H, W). The
                ``dehaze_image`` helper feeds values scaled to [0, 1].

        Returns:
            A tensor with the same shape as ``x`` and non-negative values.
        """
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        # Later layers see concatenations (along channels) of earlier
        # features, giving the 6- and 12-channel inputs declared above.
        x3 = self.relu(self.e_conv3(torch.cat((x1, x2), 1)))
        x4 = self.relu(self.e_conv4(torch.cat((x2, x3), 1)))
        k = self.relu(self.e_conv5(torch.cat((x1, x2, x3, x4), 1)))
        # Reformulated scattering model: J = K * I - K + b, with b = 1.
        return self.relu((k * x) - k + 1)
# Loaded networks keyed by (absolute weights path, device string). The
# checkpoint is read from disk once per run rather than once per image.
_DEHAZE_NET_CACHE = {}


def _get_dehaze_net(weights_path, device):
    """Return a cached AODDehazeNet on ``device``, loaded from ``weights_path``, in eval mode."""
    key = (os.path.abspath(weights_path), str(device))
    net = _DEHAZE_NET_CACHE.get(key)
    if net is None:
        net = AODDehazeNet().to(device)
        # map_location lets a checkpoint saved on one device (e.g. CUDA) load
        # on another (MPS/CPU). weights_only=True refuses arbitrary pickled
        # objects; a state dict only needs tensors.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        net.load_state_dict(state_dict)
        net.eval()
        _DEHAZE_NET_CACHE[key] = net
    return net


def _load_image_tensor(image_path, device):
    """Read an image file as a float32 tensor of shape (1, 3, H, W) in [0, 1] on ``device``."""
    with Image.open(image_path) as img:
        # Force 3 channels. Grayscale, palette or RGBA files would otherwise
        # give arrays the 3-channel network (or the HWC permute) cannot accept.
        pixels = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    tensor = torch.from_numpy(pixels).permute(2, 0, 1)  # HWC -> CHW
    return tensor.unsqueeze(0).to(device)


def dehaze_image(
    image_path,
    *,
    weights_path="AOD-net-snapshots/dehazer.pth",
    output_dir="./test-data-aod/dehazed/",
):
    """Dehaze one image file and save the result under ``output_dir``.

    The output keeps the input's file name. ``output_dir`` is created if it
    does not exist.

    Args:
        image_path: path of the hazy input image.
        weights_path: AOD-Net state-dict checkpoint. It is loaded once per
            (path, device) and reused across calls.
        output_dir: directory that receives the dehazed image.

    Returns:
        The dehazed tensor, shape (1, 3, H, W), on ``training_device``.
    """
    data_hazy = _load_image_tensor(image_path, training_device)
    dehaze_net = _get_dehaze_net(weights_path, training_device)
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        clean_image = dehaze_net(data_hazy)
    os.makedirs(output_dir, exist_ok=True)
    # os.path.basename handles OS-specific separators (glob on Windows
    # yields backslashes, which split("/") would not strip).
    torchvision.utils.save_image(
        clean_image, os.path.join(output_dir, os.path.basename(image_path))
    )
    return clean_image
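# Evaluation of the datasets (currently disabled).
# NOTE(review): before re-enabling, fix two bugs in the code below. First, the
# "path is empty" checks are inverted: `elif ... is not None` fires for every
# listing, so use `elif not input_path_list`. Second, `data_range=255` is wrong
# for images already scaled to [0, 1]; use `data_range=1.0`.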
# def dehaze_evaluate(input_path, output_path):
# input_path_list = os.listdir(input_path)
# if ".DS_Store" in input_path_list:
# input_path_list.remove(".DS_Store")
# elif input_path_list is not None:
# print("[ FAIL ] Original image path is empty.")
# output_path_list = os.listdir(output_path)
# if ".DS_Store" in output_path_list:
# output_path_list.remove(".DS_Store")
# elif output_path_list is not None:
# print("[ FAIL ] Dehazed image path is empty.")
# print("Image \tPSNR \tSSIM\n---------\t---------\t---------")
# for file_name in input_path_list:
# # original_image_path = os.path.join(input_path, file_name)
# original_image_path = os.path.join("./clear-image", file_name)
# dehazed_image_path = os.path.join(output_path, file_name)
# original_image = cv2.imread(original_image_path)
# original_image = original_image.astype("float32") / 255
# dehazed_image = cv2.imread(dehazed_image_path)
# dehazed_image = dehazed_image.astype("float32") / 255
# current_psnr = round(evaluation.compare_psnr(
# original_image, dehazed_image), 6)
# current_ssim = round(evaluation.compare_ssim(
# original_image, dehazed_image, win_size=7, data_range=255, channel_axis=2), 6)
# print(file_name, "\t", current_psnr, "\t", current_ssim)
if __name__ == "__main__":
    # Report the device chosen at import time instead of re-probing the
    # backends, so the message always matches where the model actually runs.
    if training_device.type == "mps":
        print("[ INFO ] Start process with MPS.\n")
    elif training_device.type == "cuda":
        print("[ INFO ] Start process with CUDA\n")
    else:
        print("[ INFO ] Start process with CPU\n")
    # Sorted so images are processed in a stable, reproducible order
    # (glob's order is filesystem-dependent).
    test_list = sorted(glob.glob("./test-data-aod/hazy/*"))
    if not test_list:
        print("[ WARN ] No images found in ./test-data-aod/hazy/.")
    else:
        for image in test_list:
            print("[ INFO ] Processing image: ", image)
            dehaze_image(image)
        print("[ INFO ] Success. All images done.")
    # dehaze_evaluate(
    #     input_path="./hazed-image",
    #     output_path="./dehazed-image/AOD-net"
    # )