-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy pathutil.py
More file actions
184 lines (141 loc) · 5.58 KB
/
util.py
File metadata and controls
184 lines (141 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import math
import torch
# 注意力计算函数
def attention(Q, K, V, mask):
# b句话,每句话50个词,每个词编码成32维向量,4个头,每个头分到8维向量
# Q,K,V = [b, 4, 50, 8]
# [b, 4, 50, 8] * [b, 4, 8, 50] -> [b, 4, 50, 50]
# Q,K矩阵相乘,求每个词相对其他所有词的注意力
score = torch.matmul(Q, K.permute(0, 1, 3, 2))
# 除以每个头维数的平方根,做数值缩放
score /= 8 ** 0.5
# mask遮盖,mask是true的地方都被替换成-inf,这样在计算softmax的时候,-inf会被压缩到0
# mask = [b, 1, 50, 50]
score = score.masked_fill_(mask, -float('inf'))
score = torch.softmax(score, dim=-1)
# 以注意力分数乘以V,得到最终的注意力结果
# [b, 4, 50, 50] * [b, 4, 50, 8] -> [b, 4, 50, 8]
score = torch.matmul(score, V)
# 每个头计算的结果合一
# [b, 4, 50, 8] -> [b, 50, 32]
score = score.permute(0, 2, 1, 3).reshape(-1, 50, 32)
return score
# 多头注意力计算层
class MultiHead(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc_Q = torch.nn.Linear(32, 32)
self.fc_K = torch.nn.Linear(32, 32)
self.fc_V = torch.nn.Linear(32, 32)
self.out_fc = torch.nn.Linear(32, 32)
# 规范化之后,均值是0,标准差是1
# BN是取不同样本做归一化
# LN是取不同通道做归一化
# affine=True,elementwise_affine=True,指定规范化后,再计算一个线性映射
# norm = torch.nn.BatchNorm1d(num_features=4, affine=True)
# print(norm(torch.arange(32, dtype=torch.float32).reshape(2, 4, 4)))
"""
[[[-1.1761, -1.0523, -0.9285, -0.8047],
[-1.1761, -1.0523, -0.9285, -0.8047],
[-1.1761, -1.0523, -0.9285, -0.8047],
[-1.1761, -1.0523, -0.9285, -0.8047]],
[[ 0.8047, 0.9285, 1.0523, 1.1761],
[ 0.8047, 0.9285, 1.0523, 1.1761],
[ 0.8047, 0.9285, 1.0523, 1.1761],
[ 0.8047, 0.9285, 1.0523, 1.1761]]]"""
# norm = torch.nn.LayerNorm(normalized_shape=4, elementwise_affine=True)
# print(norm(torch.arange(32, dtype=torch.float32).reshape(2, 4, 4)))
"""
[[[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416]],
[[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416]]]"""
self.norm = torch.nn.LayerNorm(normalized_shape=32, elementwise_affine=True)
self.dropout = torch.nn.Dropout(p=0.1)
def forward(self, Q, K, V, mask):
# b句话,每句话50个词,每个词编码成32维向量
# Q,K,V = [b, 50, 32]
b = Q.shape[0]
# 保留下原始的Q,后面要做短接用
clone_Q = Q.clone()
# 规范化
Q = self.norm(Q)
K = self.norm(K)
V = self.norm(V)
# 线性运算,维度不变
# [b, 50, 32] -> [b, 50, 32]
K = self.fc_K(K)
V = self.fc_V(V)
Q = self.fc_Q(Q)
# 拆分成多个头
# b句话,每句话50个词,每个词编码成32维向量,4个头,每个头分到8维向量
# [b, 50, 32] -> [b, 4, 50, 8]
Q = Q.reshape(b, 50, 4, 8).permute(0, 2, 1, 3)
K = K.reshape(b, 50, 4, 8).permute(0, 2, 1, 3)
V = V.reshape(b, 50, 4, 8).permute(0, 2, 1, 3)
# 计算注意力
# [b, 4, 50, 8] -> [b, 50, 32]
score = attention(Q, K, V, mask)
# 计算输出,维度不变
# [b, 50, 32] -> [b, 50, 32]
score = self.dropout(self.out_fc(score))
# 短接
score = clone_Q + score
return score
# 位置编码层
class PositionEmbedding(torch.nn.Module):
def __init__(self):
super().__init__()
# pos是第几个词,i是第几个维度,d_model是维度总数
def get_pe(pos, i, d_model):
fenmu = 1e4 ** (i / d_model)
pe = pos / fenmu
if i % 2 == 0:
return math.sin(pe)
return math.cos(pe)
# 初始化位置编码矩阵
pe = torch.empty(50, 32)
for i in range(50):
for j in range(32):
pe[i, j] = get_pe(i, j, 32)
pe = pe.unsqueeze(0)
# 定义为不更新的常量
self.register_buffer('pe', pe)
# 词编码层
self.embed = torch.nn.Embedding(39, 32)
# 初始化参数
self.embed.weight.data.normal_(0, 0.1)
def forward(self, x):
# [8, 50] -> [8, 50, 32]
embed = self.embed(x)
# 词编码和位置编码相加
# [8, 50, 32] + [1, 50, 32] -> [8, 50, 32]
embed = embed + self.pe
return embed
# 全连接输出层
class FullyConnectedOutput(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Sequential(
torch.nn.Linear(in_features=32, out_features=64),
torch.nn.ReLU(),
torch.nn.Linear(in_features=64, out_features=32),
torch.nn.Dropout(p=0.1),
)
self.norm = torch.nn.LayerNorm(normalized_shape=32,
elementwise_affine=True)
def forward(self, x):
# 保留下原始的x,后面要做短接用
clone_x = x.clone()
# 规范化
x = self.norm(x)
# 线性全连接运算
# [b, 50, 32] -> [b, 50, 32]
out = self.fc(x)
# 做短接
out = clone_x + out
return out