Commit ae581a6

vae

1 parent ac5dcd0 commit ae581a6

3 files changed: +140 −9 lines

Diff for: ML/Pytorch/more_advanced/VAE/model.py

+45
@@ -0,0 +1,45 @@
import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder: input -> hidden -> (mu, sigma)
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder: latent -> hidden -> reconstructed input
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        # sigmoid keeps pixel reconstructions in [0, 1]
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon  # reparameterization trick
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


if __name__ == "__main__":
    x = torch.randn(4, 28 * 28)  # batch of 4 flattened 28x28 images
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)  # torch.Size([4, 784])
    print(mu.shape)               # torch.Size([4, 20])
    print(sigma.shape)            # torch.Size([4, 20])
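For reference, the sampling step in forward is the reparameterization trick: rather than sampling z from the posterior directly, the model draws the noise separately,

$$ z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), $$

so gradients can flow back through mu and sigma. Note that hid_2sigma is an unconstrained linear head whose output is used directly as sigma; some implementations predict the log-variance instead to guarantee positivity.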

Diff for: ML/Pytorch/more_advanced/VAE/train.py

+86
@@ -0,0 +1,86 @@
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

# Start Training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        # Forward pass: flatten images to (batch, 784)
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # Compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        # KL divergence to a standard normal prior (factor of 1/2 omitted)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # Backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())


model = model.to("cpu")

def inference(digit, num_examples=1):
    """
    Generates (num_examples) of a particular digit.
    Specifically, we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    # collect one example image for each digit 0..9
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
            encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")

for idx in range(10):
    inference(idx, num_examples=5)
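The kl_div line is the closed-form KL divergence between the approximate posterior and the standard normal prior,

$$ D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,1)\big) = -\frac{1}{2}\sum\big(1 + \log\sigma^{2} - \mu^{2} - \sigma^{2}\big), $$

except that the code omits the 1/2 factor, which doubles the KL weight relative to the textbook ELBO. The summed BCELoss is the reconstruction term; the sigmoid in decode keeps x_reconstructed in [0, 1], which BCELoss requires.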

Diff for: ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

+9-9
@@ -23,25 +23,25 @@ def __init__(self, embed_size, heads):
             self.head_dim * heads == embed_size
         ), "Embedding size needs to be divisible by heads"

-        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
-        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
+        self.values = nn.Linear(embed_size, embed_size)
+        self.keys = nn.Linear(embed_size, embed_size)
+        self.queries = nn.Linear(embed_size, embed_size)
+        self.fc_out = nn.Linear(embed_size, embed_size)

     def forward(self, values, keys, query, mask):
         # Get number of training examples
         N = query.shape[0]

         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

+        values = self.values(values)  # (N, value_len, embed_size)
+        keys = self.keys(keys)  # (N, key_len, embed_size)
+        queries = self.queries(query)  # (N, query_len, embed_size)
+
         # Split the embedding into self.heads different pieces
         values = values.reshape(N, value_len, self.heads, self.head_dim)
         keys = keys.reshape(N, key_len, self.heads, self.head_dim)
-        query = query.reshape(N, query_len, self.heads, self.head_dim)
-
-        values = self.values(values)  # (N, value_len, heads, head_dim)
-        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
-        queries = self.queries(query)  # (N, query_len, heads, heads_dim)
+        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

         # Einsum does matrix mult. for query*keys for each training example
         # with every other training example, don't be confused by einsum
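This refactor moves the Q/K/V projections ahead of the head split: each input now passes through a single fused embed_size -> embed_size linear layer (with a bias, since bias=False was dropped) and is only then reshaped into (heads, head_dim), rather than being projected per-head with head_dim -> head_dim maps after the split. A minimal shape sketch of the new order of operations, using hypothetical sizes not taken from the commit:

import torch
from torch import nn

# hypothetical sizes for illustration only
N, seq_len, embed_size, heads = 2, 5, 256, 8
head_dim = embed_size // heads  # 32

proj = nn.Linear(embed_size, embed_size)  # one fused projection covering all heads

x = torch.randn(N, seq_len, embed_size)
q = proj(x)                                 # (N, seq_len, embed_size)
q = q.reshape(N, seq_len, heads, head_dim)  # split into heads afterwards
print(q.shape)  # torch.Size([2, 5, 8, 32])

One fused (embed_size, embed_size) matrix learns all heads' projections jointly, which matches the standard multi-head attention formulation.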
