-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathmodel.py
More file actions
54 lines (45 loc) · 1.54 KB
/
model.py
File metadata and controls
54 lines (45 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from torch import nn, optim
from torch.nn import functional as F
class AE(nn.Module):
    """Fully connected autoencoder for dimensional reduction.

    Architecture (every layer is an ``nn.Linear``):
        encoder: dim -> hidden_dim -> latent_dim, ReLU after each layer
        decoder: latent_dim -> hidden_dim -> dim, ReLU after each layer

    Because the decoder's last layer is followed by a ReLU, reconstructions
    are always non-negative. NOTE(review): this presumably assumes
    non-negative input data; confirm against the training data and loss.

    Args:
        dim: Number of input features per sample.
        hidden_dim: Width of the hidden layer on both the encoder and the
            decoder side. Defaults to 512, the previously hard-coded value.
        latent_dim: Size of the bottleneck code. Defaults to 128, the
            previously hard-coded value.
    """

    def __init__(self, dim, hidden_dim=512, latent_dim=128):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, dim)

    def encode(self, x):
        """Map rows of ``dim`` features to ReLU-activated ``latent_dim`` codes."""
        h1 = F.relu(self.fc1(x))
        return F.relu(self.fc2(h1))

    def decode(self, z):
        """Map ``latent_dim`` codes back to non-negative ``dim``-feature rows."""
        h3 = F.relu(self.fc3(z))
        return F.relu(self.fc4(h3))

    def forward(self, x):
        """Encode and then decode ``x``.

        ``x`` is first flattened to ``(-1, dim)``, so its total number of
        elements must be a multiple of ``dim``. ``reshape`` is used instead of
        ``view`` so that non-contiguous inputs (e.g. transposed or sliced
        batches) are accepted; contiguous inputs are still not copied.

        Returns:
            ``(reconstruction, z)`` with shapes ``(N, dim)`` and
            ``(N, latent_dim)``, where ``N = x.numel() // dim``.
        """
        z = self.encode(x.reshape(-1, self.dim))
        return self.decode(z), z
class VAE(nn.Module):
    """Variational autoencoder for dimensional reduction.

    Architecture (every layer is an ``nn.Linear``):
        encoder: dim -> hidden_dim (ReLU), then two parallel heads
                 hidden_dim -> latent_dim producing ``mu`` and ``logvar``
        decoder: latent_dim -> hidden_dim (ReLU) -> dim (sigmoid)

    The decoder ends in a sigmoid, so reconstructions lie in [0, 1].
    NOTE(review): this presumably expects inputs scaled to [0, 1]; confirm
    against the training loss.

    Args:
        dim: Number of input features per sample.
        hidden_dim: Width of the hidden layer on both the encoder and the
            decoder side. Defaults to 400, the previously hard-coded value.
        latent_dim: Size of the latent space. Defaults to 20, the
            previously hard-coded value.
    """

    def __init__(self, dim, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # head producing mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # head producing logvar
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, dim)

    def encode(self, x):
        """Return ``(mu, logvar)``, each with ``latent_dim`` features per row."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.

        Sampling happens unconditionally: there is no ``self.training``
        check, so eval-mode calls are stochastic as well. Use ``mu`` directly
        if a deterministic embedding is needed.
        """
        std = torch.exp(0.5 * logvar)  # logvar is a log-variance: std = sqrt(exp(logvar))
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map ``latent_dim`` codes to ``dim``-feature rows in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent code, and decode ``x``.

        ``x`` is first flattened to ``(-1, dim)``, so its total number of
        elements must be a multiple of ``dim``. ``reshape`` is used instead of
        ``view`` so that non-contiguous inputs are accepted; contiguous inputs
        are still not copied.

        Returns:
            ``(reconstruction, mu, logvar, z)``. ``reconstruction`` has shape
            ``(N, dim)``; the other three have shape ``(N, latent_dim)``,
            where ``N = x.numel() // dim``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z