-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
116 lines (83 loc) · 3.21 KB
/
utils.py
File metadata and controls
116 lines (83 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Use the GPU only when one is actually available; a hard-coded `cuda = True`
# makes every `.to(DEVICE)` call fail on CPU-only machines.
cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if cuda else "cpu")

# Code for the VAE adapted from
# https://github.com/Jackson-Kang/Pytorch-VAE-tutorial/blob/master/01_Variational_AutoEncoder.ipynb
class View(nn.Module):
    """Module that reshapes its input with ``Tensor.view`` to a fixed shape."""

    def __init__(self, shape):
        super().__init__()
        # Target shape, unpacked into ``Tensor.view``; a -1 entry is inferred.
        self.shape = shape

    def forward(self, x):
        target = self.shape
        return x.view(*target)
class Encoder(nn.Module):
    """MLP encoder producing the parameters of the normal distribution q(z|x).

    Two LeakyReLU(0.2) hidden layers feed two linear heads: one for the mean
    and one for the log of the variance.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of both hidden layers.
        latent_dim: size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        # No explicit `self.training = True`: nn.Module.__init__ already sets
        # it, and .train()/.eval() are the supported way to change it.

    def forward(self, x):
        """Return ``(mean, log_var)``, each with last dimension ``latent_dim``."""
        h_ = self.LeakyReLU(self.FC_input(x))
        h_ = self.LeakyReLU(self.FC_input2(h_))
        mean = self.FC_mean(h_)
        # The head predicts the *log* of the variance, so its unconstrained
        # linear output is always a valid parameter.
        log_var = self.FC_var(h_)
        return mean, log_var
class Decoder(nn.Module):
    """MLP decoder mapping a latent code to an output squashed by a sigmoid.

    Args:
        latent_dim: size of the last dimension of the latent input.
        hidden_dim: width of both hidden layers.
        output_dim: size of the reconstructed output.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Decode ``x``; every output element lies in [0, 1] due to the sigmoid."""
        act = self.LeakyReLU
        hidden = act(self.FC_hidden2(act(self.FC_hidden(x))))
        return torch.sigmoid(self.FC_output(hidden))
class Model(nn.Module):
    """VAE combining an encoder and a decoder via the reparameterization trick.

    Args:
        Encoder: module returning ``(mean, log_var)`` for an input ``x``.
        Decoder: module mapping a latent sample ``z`` to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Return ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Despite its name, ``var`` acts as the standard deviation: ``forward``
        passes ``exp(0.5 * log_var)``. The parameter name is kept for callers.
        """
        # randn_like already places epsilon on var's device/dtype; moving it to
        # the global DEVICE broke models living on a different device.
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon  # reparameterization trick
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent ``z`` and decode it.

        Returns:
            ``(x_hat, mean, log_var)``: the reconstruction plus the encoder
            outputs consumed by ``loss_function``.
        """
        mean, log_var = self.Encoder(x)
        # exp(0.5 * log_var) turns the log-variance into the standard deviation.
        std = torch.exp(0.5 * log_var)
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var
def loss_function(x, x_hat, mean, log_var):
    """Return the summed VAE loss (negative ELBO under a Bernoulli likelihood).

    The loss is the sum-reduced binary cross-entropy between ``x_hat`` and ``x``
    plus the closed-form KL divergence between N(mean, exp(log_var)) and N(0, I).
    """
    kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    bce = nn.functional.binary_cross_entropy(x_hat, x, reduction="sum")
    return bce + kl_divergence
def sampling(mean, var, samples):
    """Draw ``samples`` perturbed latent codes for every row of ``mean``.

    ``mean`` and ``var`` are tiled ``samples`` times along dim 0. Each element is
    then ``mean * alpha + var * epsilon`` with ``alpha ~ N(1, 0.5**2)`` and
    ``epsilon ~ N(0, 1)``.

    Args:
        mean: tensor of latent means.
        var: tensor (assumed same shape as ``mean``) scaling the Gaussian noise;
            presumably the standard deviation, as in ``Model.reparameterization``
            — confirm with callers.
        samples: number of copies; must be >= 1, since ``torch.cat`` rejects an
            empty list.

    Returns:
        Tensor with ``samples * mean.shape[0]`` rows, on ``var``'s device.
    """
    mean = torch.cat(samples * [mean])
    var = torch.cat(samples * [var])
    # Both noise tensors are created from `var` so they share its device and
    # dtype. A CPU-only alpha, or an epsilon pinned to the global DEVICE,
    # raises a device-mismatch error when the inputs live elsewhere.
    epsilon = torch.randn_like(var)  # sampling epsilon
    alpha = torch.empty_like(var).normal_(mean=1, std=0.5)  # sampling alpha
    z = mean * alpha + var * epsilon  # sampling
    return z
def show_image(x, batch_size, idx):
    """Display image ``idx`` from a batch of flattened 28x28 images.

    Args:
        x: tensor with ``batch_size * 28 * 28`` elements (e.g. flattened MNIST
            digits or decoder reconstructions).
        batch_size: number of images in ``x``.
        idx: index of the image to display.
    """
    x = x.view(batch_size, 28, 28)
    plt.figure()
    # detach(): tensors produced by the model carry autograd history, and
    # Tensor.numpy() refuses tensors that require grad.
    plt.imshow(x[idx].detach().cpu().numpy())
    plt.show()
    plt.close()