-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFFA.py
115 lines (94 loc) · 3.49 KB
/
FFA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch.nn as nn
import torch
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2`` on every side.

    For an odd ``kernel_size`` this padding keeps the spatial size of the
    input unchanged, which the residual additions elsewhere rely on.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size.
        bias: whether the convolution has a learnable bias.

    Returns:
        The constructed ``nn.Conv2d`` module.
    """
    half_width = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_width,
        bias=bias,
    )
class PALayer(nn.Module):
    """Pixel attention: gate every spatial position of the input by a learned weight.

    A 1x1-conv bottleneck maps ``channel`` feature maps down to a single map,
    a sigmoid squashes it into (0, 1), and the result is broadcast-multiplied
    back onto the input across the channel axis.

    Args:
        channel: number of input (and output) channels.
        reduction: divisor for the bottleneck width. Defaults to 8, the value
            that was previously hard-coded.

    Raises:
        ValueError: if ``reduction`` is smaller than 1.
    """

    def __init__(self, channel, reduction=8):
        super(PALayer, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction!r}")
        # Clamp to at least one hidden channel. With channel < reduction the
        # integer division would give 0, leaving a degenerate, empty bottleneck.
        hidden = max(1, channel // reduction)
        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel sigmoid gate computed from ``x``.

        Assumes ``x`` is laid out as (N, channel, H, W). The gate has one
        channel and broadcasts over all ``channel`` feature maps.
        """
        y = self.pa(x)
        return x * y
class CALayer(nn.Module):
    """Channel attention: rescale each feature map by a learned per-channel weight.

    The input is globally average-pooled to a 1x1 spatial size. A 1x1-conv
    bottleneck then produces one sigmoid weight per channel, and that weight
    is broadcast over the spatial dimensions of the input.

    Args:
        channel: number of input (and output) channels.
        reduction: divisor for the bottleneck width. Defaults to 8, the value
            that was previously hard-coded.

    Raises:
        ValueError: if ``reduction`` is smaller than 1.
    """

    def __init__(self, channel, reduction=8):
        super(CALayer, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction!r}")
        # Clamp to at least one hidden channel. With channel < reduction the
        # integer division would give 0, leaving a degenerate, empty bottleneck.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by weights computed from its global average.

        Assumes ``x`` is laid out as (N, channel, H, W). The weights have
        shape (N, channel, 1, 1) and broadcast over H and W.
        """
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y
class Block(nn.Module):
    """Basic unit: conv/ReLU with a local skip, conv, channel + pixel attention, outer skip.

    Data flow in ``forward``:
        h = relu(conv1(x)) + x
        h = palayer(calayer(conv2(h)))
        return h + x

    Args:
        conv: factory called as ``conv(in_ch, out_ch, kernel_size, bias=True)``
            that returns a module (for example ``default_conv``).
        dim: channel count of both the input and the output.
        kernel_size: kernel size passed to ``conv``.

    NOTE(review): both skip additions assume ``conv`` preserves the spatial
    size. That holds for ``default_conv`` with an odd ``kernel_size``.
    """

    def __init__(self, conv, dim, kernel_size):
        super().__init__()
        # Keep this creation order: it fixes the RNG draw order used for
        # parameter initialisation.
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        hidden = self.act1(self.conv1(x)) + x
        hidden = self.palayer(self.calayer(self.conv2(hidden)))
        hidden += x
        return hidden
class Group(nn.Module):
    """A residual group: ``blocks`` chained ``Block`` modules, a trailing conv, and a skip.

    ``forward`` returns ``gp(x) + x``, where ``gp`` is the sequential stack.

    Args:
        conv: convolution factory, called as ``conv(in_ch, out_ch, kernel_size)``.
        dim: channel count of both the input and the output.
        kernel_size: kernel size passed to ``conv``.
        blocks: number of ``Block`` modules in the stack.
    """

    def __init__(self, conv, dim, kernel_size, blocks):
        super().__init__()
        layers = []
        for _ in range(blocks):
            layers.append(Block(conv, dim, kernel_size))
        # The closing convolution is built after all blocks, as before.
        layers.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*layers)

    def forward(self, x):
        out = self.gp(x)
        # In-place add onto the freshly produced stack output; `x` itself is untouched.
        out.add_(x)
        return out
class FFA(nn.Module):
    """Multi-group network that fuses its group outputs with learned channel weights.

    Pipeline:
        1. ``pre``: conv 3 -> 64 channels.
        2. Groups ``g1`` ... ``g{gps}``, applied one after another; every
           group output is kept.
        3. ``ca``: channel-attention weights computed from the concatenated
           group outputs.
        4. Weighted sum of the group outputs, one weight per (group, channel).
        5. ``palayer``: pixel attention.
        6. ``post``: conv 64 -> 64 -> 3, plus a global skip to the input.

    Args:
        gps: number of groups; must be a positive integer. Only 3 used to be
            accepted, enforced by an ``assert`` that disappears under
            ``python -O``. Groups are still registered as attributes ``g1``,
            ``g2``, ..., so parameter names for ``gps=3`` are unchanged.
        blocks: number of ``Block`` modules inside each group.
        conv: convolution factory, called as ``conv(in_ch, out_ch, kernel_size)``.

    Raises:
        ValueError: if ``gps`` is smaller than 1.
    """

    def __init__(self, gps, blocks, conv=default_conv):
        super(FFA, self).__init__()
        if gps < 1:
            raise ValueError(f"gps must be >= 1, got {gps!r}")
        self.gps = gps
        self.dim = 64
        kernel_size = 3
        # Modules are constructed in the same order as before (pre conv,
        # groups, fusion attention, pixel attention, post convs), so parameter
        # initialisation consumes the RNG identically.
        pre_process = [conv(3, self.dim, kernel_size)]
        for idx in range(1, self.gps + 1):
            # setattr registers the submodule under the historical name g<idx>.
            setattr(self, f"g{idx}", Group(conv, self.dim, kernel_size, blocks=blocks))
        self.ca = nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True),
            nn.Sigmoid()
        ])
        self.palayer = PALayer(self.dim)
        post_process = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)]
        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_process)

    def forward(self, x1):
        """Run the network.

        Assumes ``x1`` has 3 channels, e.g. (N, 3, H, W), since ``pre`` expects
        3 input channels. The global skip (``post(...) + x1``) assumes the
        convs preserve spatial size, which holds for ``default_conv`` with the
        odd kernel size used here.
        """
        feat = self.pre(x1)
        group_outs = []
        for idx in range(1, self.gps + 1):
            feat = getattr(self, f"g{idx}")(feat)
            group_outs.append(feat)
        w = self.ca(torch.cat(group_outs, dim=1))
        # (N, gps*dim, 1, 1) -> (N, gps, dim, 1, 1): per-channel weights for each group.
        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
        # Accumulate left to right so gps=3 matches the original w0*r1 + w1*r2 + w2*r3.
        out = w[:, 0] * group_outs[0]
        for idx in range(1, self.gps):
            out = out + w[:, idx] * group_outs[idx]
        out = self.palayer(out)
        return self.post(out) + x1
if __name__ == "__main__":
    # Build a 3-group model with 19 blocks per group and print its module tree.
    model = FFA(gps=3, blocks=19)
    print(model)