# model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
class GODArecovery(nn.Module):
def __init__(self, squences_length, recovery_length):
super().__init__()
self.squences_length = squences_length
self.recovery_length = recovery_length
self.log_std_min = -20
self.log_std_max = 2
        # Recover states: input [64, 17, 10], output [64, 17, 5]
self.recovery_states_mean = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
self.recovery_states_logstd = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
        # Recover actions: input [64, 6, 10], output [64, 6, 5]
self.recovery_actions_mean = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
self.recovery_actions_logstd = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
        # Recover rewards: input [64, 1, 10], output [64, 1, 5]
self.recovery_rewards_mean = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
self.recovery_rewards_logstd = nn.Sequential(
nn.Linear(self.squences_length, self.recovery_length),
)
def forward(self, states, actions, rewards, target_states, target_actions, target_rewards):
        rewards = rewards.reshape(-1, self.squences_length, 1)
        target_rewards = target_rewards.reshape(-1, self.squences_length, 1)
        # Predict the means of the new sequence's states, actions, and rewards
new_states_mean = self.recovery_states_mean(states.permute(0,2,1)).permute(0,2,1)#[64,5,17]
new_actions_mean = self.recovery_actions_mean(actions.permute(0,2,1)).permute(0,2,1)#[64,5,6]
new_rewards_mean = self.recovery_rewards_mean(rewards.permute(0,2,1)).permute(0,2,1)#[64,5,1]
        # Predict the standard deviations of the new sequence's states, actions, and rewards
new_states_logstd = self.recovery_states_logstd(states.permute(0,2,1)).permute(0,2,1)#[64,5,17]
new_states_logstd = torch.clamp(new_states_logstd, self.log_std_min, self.log_std_max)
new_states_stds = torch.exp(new_states_logstd)
new_actions_logstd = self.recovery_actions_logstd(actions.permute(0,2,1)).permute(0,2,1)#[64,5,6]
new_actions_logstd = torch.clamp(new_actions_logstd, self.log_std_min, self.log_std_max)
new_actions_stds = torch.exp(new_actions_logstd)
new_rewards_logstd = self.recovery_rewards_logstd(rewards.permute(0,2,1)).permute(0,2,1)#[64,5,1]
new_rewards_logstd = torch.clamp(new_rewards_logstd, self.log_std_min, self.log_std_max)
new_rewards_stds = torch.exp(new_rewards_logstd)
        # Turn the predictions into distributions
        new_states_distribution = Normal(new_states_mean, new_states_stds)  # elementwise Gaussian over each state dimension
        new_actions_distribution = Independent(TransformedDistribution(Normal(new_actions_mean, new_actions_stds), TanhTransform(cache_size=1)), 1)  # tanh-squashed Gaussian, joint over the action dimensions
new_rewards_distribution = Normal(new_rewards_mean,new_rewards_stds)
        # Check how accurately the predicted distributions fit.
        # Evaluating the original targets under the newly predicted distributions gives the log-probability of recovering them.
        # If the density at the target is high, the distribution fits well and the negative log-probability loss is small;
        # if it is low, the distribution is not good enough yet and the loss is large.
        new_states_log_probs = new_states_distribution.log_prob(target_states[:, 5:10])  # this previously raised errors because target_states was not normalized
eps = torch.finfo(target_actions.dtype).eps
target_actions = torch.clamp(target_actions, -1+eps, 1-eps)
new_actions_log_probs = new_actions_distribution.log_prob(target_actions[:,5:10])
new_rewards_log_probs = new_rewards_distribution.log_prob(target_rewards[:,5:10])
return new_states_log_probs, new_actions_log_probs, new_rewards_log_probs
def get_value(self, states, actions, rewards, target_states, target_actions, target_rewards):
        rewards = rewards.reshape(-1, self.squences_length, 1)
        target_rewards = target_rewards.reshape(-1, self.squences_length, 1)
        # Predict the means of the new sequence's states, actions, and rewards
new_states_mean = self.recovery_states_mean(states.permute(0,2,1)).permute(0,2,1)#[64,5,17]
new_actions_mean = self.recovery_actions_mean(actions.permute(0,2,1)).permute(0,2,1)#[64,5,6]
new_rewards_mean = self.recovery_rewards_mean(rewards.permute(0,2,1)).permute(0,2,1)#[64,5,1]
return new_states_mean, new_actions_mean, new_rewards_mean
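

# --- Usage sketch (an assumption, not part of the original training pipeline) ---
# A minimal, hedged example of how GODArecovery might be exercised, using the
# shapes implied by the comments above: batch 64, full sequence length 10,
# recovered length 5, 17-dim states, 6-dim actions, scalar rewards. The random
# tensors below are placeholders for real dataset batches.
def _godarecovery_demo():
    recovery = GODArecovery(squences_length=10, recovery_length=5)
    states = torch.randn(64, 10, 17)
    actions = torch.tanh(torch.randn(64, 10, 6))  # actions assumed to lie in (-1, 1)
    rewards = torch.randn(64, 10)
    s_lp, a_lp, r_lp = recovery(states, actions, rewards, states, actions, rewards)
    # The negative mean log-likelihood of the recovered half-sequence can serve as the recovery loss.
    return -(s_lp.mean() + a_lp.mean() + r_lp.mean())
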
class GuideVAE(nn.Module):
"""Implementation of GuideVAE"""
def __init__(self, feature_size, class_size, latent_size):
super(GuideVAE, self).__init__()
self.fc1x = nn.Linear(feature_size, 150)
self.fc1y = nn.Linear(class_size, 150)
self.fc2_mu = nn.Linear(300, latent_size)
self.fc2_log_std = nn.Linear(300, latent_size)
self.fc2_mu_ = nn.Linear(150, latent_size)
self.fc2_log_std_ = nn.Linear(150, latent_size)
self.fc3 = nn.Linear(latent_size + class_size, 300)
self.fc3_ = nn.Linear(latent_size , 300)
self.fc4_mu = nn.Linear(300, feature_size)
self.fc4_log_std = nn.Linear(300, feature_size)
def encode(self, x, y, flag = True):
if flag:
            h1x = F.relu(self.fc1x(x))  # embed features: [64, feature_size] -> [64, 150]
            h1y = F.relu(self.fc1y(y))  # embed labels:   [64, class_size]   -> [64, 150]
            h1 = torch.cat([h1x, h1y], dim=1)  # concat feature and label embeddings -> [64, 300]
z_mu = self.fc2_mu(h1)#[64,300]->[64,32]
z_log_std = self.fc2_log_std(h1)#[64,300]->[64,32]
else:
            h1 = F.relu(self.fc1x(x))  # features only, no label: [64,120]->[64,150]
            z_mu = self.fc2_mu_(h1)  # [64,150]->[64,32]
            z_log_std = self.fc2_log_std_(h1)  # [64,150]->[64,32]
return z_mu, z_log_std
def decode(self, z, y, flag = True):
if flag:
h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) # concat latents and labels
            recon_mu = torch.sigmoid(self.fc4_mu(h3))  # sigmoid keeps the reconstruction mean in [0, 1]; targets are assumed to be scaled to that range
recon_log_std = self.fc4_log_std(h3)
else:
            h3 = F.relu(self.fc3_(z))  # latents only, no label
            recon_mu = torch.sigmoid(self.fc4_mu(h3))  # sigmoid keeps the reconstruction mean in [0, 1], as above
recon_log_std = self.fc4_log_std(h3)
return recon_mu, recon_log_std
    def reparametrize(self, z_mu, z_log_std):  # reparameterization trick
        z_std = torch.exp(z_log_std)
        z_eps = torch.randn_like(z_std)  # sample from a standard normal distribution
z = z_mu + z_eps * z_std
return z
def forward(self, x, y):
z1_mu, z1_log_std = self.encode(x, y, flag = True)
z1 = self.reparametrize(z1_mu, z1_log_std)
recon_mu, recon_log_std = self.decode(z1, y, flag = True)
recon_log_std = torch.clamp(recon_log_std, -20, 2)
recon_std = torch.exp(recon_log_std)
return recon_mu, recon_std, z1_mu, z1_log_std
def reconstruct(self, x):
z2_mu, z2_log_std = self.encode(x, 0, flag = False)
z2 = self.reparametrize(z2_mu, z2_log_std)
        recon_mu, recon_log_std = self.decode(z2, 0, flag=False)  # the decoded output is not used here; only the latent statistics are returned
return z2_mu, z2_log_std
def loss_function(self, recon_mu, recon_std, y, z1_mu, z1_log_std, z2_mu, z2_log_std) -> torch.Tensor:
        # Score the targets under the reconstruction distribution: the log-probability of y under the predicted Normal.
        # A high density at the target means the distribution fits well and the loss term is small; a low density means the reconstruction still needs improvement.
recon_distribution = Normal(recon_mu, recon_std)
recon_log_probs = recon_distribution.log_prob(y)
        recon_loss = -torch.mean(recon_log_probs)  # using "mean" here may have a bad effect on the gradients. A VAE wants E[log p(x|z)] to be as large as possible and the KL divergence to be as small as possible; the two terms are independent, but both contribute strongly to training.
        # KL(N(z1_mu, exp(z1_log_std)^2) || N(z2_mu, exp(z2_log_std)^2)) for diagonal Gaussians, summed over all dimensions
        kl_loss = torch.sum(z2_log_std - z1_log_std - 0.5 + (torch.exp(2 * z1_log_std) + (z1_mu - z2_mu).pow(2)) / (2 * torch.exp(2 * z2_log_std)))
loss = recon_loss + kl_loss
return loss
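

# --- Training-step sketch (an assumption, not part of the original file) ---
# A minimal, hedged example of one GuideVAE update, assuming feature_size and
# class_size are equal so that the decoder output can be scored against y in
# loss_function. The shapes (batch 64, 120-dim features, 32-dim latents) and the
# Adam optimizer are illustrative choices, not taken from the original code.
def _guidevae_demo():
    vae = GuideVAE(feature_size=120, class_size=120, latent_size=32)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    x = torch.rand(64, 120)
    y = torch.rand(64, 120)  # targets assumed to lie in [0, 1], matching the sigmoid decoder mean
    recon_mu, recon_std, z1_mu, z1_log_std = vae(x, y)  # posterior path, conditioned on both x and y
    z2_mu, z2_log_std = vae.reconstruct(x)              # prior-like path, conditioned on x alone
    loss = vae.loss_function(recon_mu, recon_std, y, z1_mu, z1_log_std, z2_mu, z2_log_std)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
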
# class PredictVAE(nn.Module):
# """Implementation of predictVAE"""
# def __init__(self, feature_size, latent_size):
# super(PredictVAE, self).__init__()
# self.fc1 = nn.Linear(feature_size , 200)
# self.fc2_z2_mu = nn.Linear(200, latent_size)
# self.fc2_z2_log_std = nn.Linear(200, latent_size)
# self.fc3 = nn.Linear(latent_size, 200)
# self.fc4_mu = nn.Linear(200, feature_size)
# self.fc4_log_std = nn.Linear(200, feature_size)
# def encode(self, x):
# h1 = F.relu(self.fc1(x)) # concat features and labels [64,120]->[64,200]
# z2_mu = self.fc2_z2_mu(h1)#[64,200]->[64,32]
# z2_log_std = self.fc2_z2_log_std(h1)#[64,200]->[64,32]
# return z2_mu, z2_log_std
# def decode(self, z2):
# h3 = F.relu(self.fc3(z2)) # concat latents and labels
# recon_mu = torch.sigmoid(self.fc4_mu(h3)) # use sigmoid because the input image's pixel is between 0-1
# recon_log_std = self.fc4_log_std(h3)
# return recon_mu, recon_log_std
# def reparametrize(self, z2_mu, z2_log_std):  # reparameterization trick
# z2_std = torch.exp(z2_log_std)
# z2_eps = torch.randn_like(z2_std)  # sample from a standard normal distribution
# z2 = z2_mu + z2_eps * z2_std
# return z2
# def forward(self, x):
# z2_mu, z2_log_std = self.encode(x)
# z2 = self.reparametrize(z2_mu, z2_log_std)
# recon_mu, recon_log_std = self.decode(z2)
# return z2_mu, z2_log_std
# def loss_function(self, z1_mu, z1_log_std, z2_mu, z2_log_std) -> torch.Tensor:
# kl_loss = (z2_log_std-z1_log_std-0.5+(torch.exp(2*z1_log_std)+(z1_mu-z2_mu).pow(2))/(2*torch.exp(2*z2_log_std)))
# return kl_loss