layers.py
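"""Minimal re-implementations of common PyTorch layers: BatchNorm2d with
tracked running statistics, a sigmoid-based GELU approximation, Conv2d built
on unfold + matmul (im2col), and Linear."""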
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable affine parameters, broadcastable over (N, C, H, W).
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Running statistics used at inference time.
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))

    def forward(self, x):
        if self.training:
            # Normalize with the biased batch statistics, as nn.BatchNorm2d does.
            batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
            batch_var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
            with torch.no_grad():
                # nn.BatchNorm2d tracks the *unbiased* variance in running_var,
                # so apply Bessel's correction before the EMA update. Update the
                # buffers in place so they stay registered buffers.
                n = x.numel() / x.size(1)
                unbiased_var = batch_var * n / (n - 1)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
            mean = batch_mean
            var = batch_var
        else:
            mean = self.running_mean
            var = self.running_var
        x_norm = (x - mean) / (var + self.eps).sqrt()
        return self.gamma * x_norm + self.beta
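

# A minimal sanity check, not part of the original file: in training mode the
# normalization above uses the same biased batch statistics as nn.BatchNorm2d,
# so the two outputs should agree closely (both start with gamma=1, beta=0).
def _check_batchnorm():
    torch.manual_seed(0)
    x = torch.randn(4, 8, 5, 5)
    ours = BatchNorm2d(8)
    ref = nn.BatchNorm2d(8)
    ours.train()
    ref.train()
    assert torch.allclose(ours(x), ref(x), atol=1e-5)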


class Gelu(nn.Module):
    def forward(self, x):
        # Sigmoid approximation of GELU: x * sigmoid(1.702 * x).
        return x * torch.sigmoid(1.702 * x)
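

# Illustrative check, not part of the original file: the sigmoid form is only
# an approximation of the exact GELU, so compare against F.gelu with a loose
# tolerance (the worst-case gap is on the order of 0.02).
def _check_gelu():
    x = torch.linspace(-4, 4, 101)
    assert torch.allclose(Gelu()(x), F.gelu(x), atol=0.05)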


class Conv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_size) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        B, C, H, W = x.shape
        kh, kw = self.kernel_size
        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding, self.padding, self.padding))
        h_out = (H + 2 * self.padding - kh) // self.stride + 1
        w_out = (W + 2 * self.padding - kw) // self.stride + 1
        # im2col: each column of x_unfold is one flattened receptive field.
        x_unfold = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        # Convolution expressed as a single matmul over the unfolded patches.
        weights = self.weight.view(self.out_channels, -1)
        out = torch.matmul(weights, x_unfold)
        out = out.view(B, self.out_channels, h_out, w_out)
        if self.bias is not None:
            out += self.bias.view(1, -1, 1, 1)
        return out
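

# Illustrative check, not part of the original file: unfold + matmul should
# reproduce F.conv2d exactly for the same weight, bias, stride, and padding.
def _check_conv2d():
    torch.manual_seed(0)
    conv = Conv2d(3, 6, kernel_size=3, stride=2, padding=1)
    x = torch.randn(2, 3, 16, 16)
    ref = F.conv2d(x, conv.weight, conv.bias, stride=conv.stride, padding=conv.padding)
    assert torch.allclose(conv(x), ref, atol=1e-5)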


class Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # Kaiming-style init: scale the weights by sqrt(2 / fan_in).
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * (2 / in_features) ** 0.5)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        out = torch.matmul(x, self.weight.t())
        if self.bias is not None:
            out += self.bias
        return out
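

# Illustrative check, not part of the original file: the matmul above should
# match F.linear for the same weight and bias.
def _check_linear():
    torch.manual_seed(0)
    lin = Linear(16, 4)
    x = torch.randn(2, 16)
    assert torch.allclose(lin(x), F.linear(x, lin.weight, lin.bias), atol=1e-5)


if __name__ == "__main__":
    _check_batchnorm()
    _check_gelu()
    _check_conv2d()
    _check_linear()
    print("all layer checks passed")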