test_discriminator_backward.py
"""Test that backpropagating the real and fake discriminator losses separately
accumulates the same gradients as backpropagating their sum once."""
import copy
import random

import torch
from torch import nn as nn


class ToyDiscriminator(nn.Module):
    """A tiny discriminator: two conv + BN + LeakyReLU blocks and a linear head."""

    def __init__(self):
        super(ToyDiscriminator, self).__init__()
        self.conv0 = nn.Conv2d(3, 4, 3, 1, 1, bias=True)
        self.bn0 = nn.BatchNorm2d(4, affine=True)
        self.conv1 = nn.Conv2d(4, 4, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(4, affine=True)
        self.linear = nn.Linear(4 * 6 * 6, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.lrelu(self.bn0(self.conv0(x)))
        feat = self.lrelu(self.bn1(self.conv1(feat)))
        feat = feat.view(feat.size(0), -1)
        out = torch.sigmoid(self.linear(feat))
        return out


def main():
    # use a fixed random seed for reproducibility
    manual_seed = 999
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    img_real = torch.rand((1, 3, 6, 6))
    img_fake = torch.rand((1, 3, 6, 6))

    # two discriminators with identical initial weights
    net_d_1 = ToyDiscriminator()
    net_d_2 = copy.deepcopy(net_d_1)
    net_d_1.train()
    net_d_2.train()

    criterion = nn.BCELoss()
    real_label = 1
    fake_label = 0

    for k, v in net_d_1.named_parameters():
        print(k, v.size())

    ###########################
    # (1) Backward the D network twice, as the official tutorial does:
    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    ###########################
    net_d_1.zero_grad()
    # real
    output = net_d_1(img_real).view(-1)
    label = output.new_ones(output.size()) * real_label
    loss_real = criterion(output, label)
    loss_real.backward()
    # fake
    output = net_d_1(img_fake).view(-1)
    label = output.new_ones(output.size()) * fake_label
    loss_fake = criterion(output, label)
    loss_fake.backward()

    ###########################
    # (2) Backward the D network once, on the summed loss
    ###########################
    net_d_2.zero_grad()
    # real
    output = net_d_2(img_real).view(-1)
    label = output.new_ones(output.size()) * real_label
    loss_real = criterion(output, label)
    # fake
    output = net_d_2(img_fake).view(-1)
    label = output.new_ones(output.size()) * fake_label
    loss_fake = criterion(output, label)
    loss = loss_real + loss_fake
    loss.backward()

    ###########################
    # Compare the accumulated gradients of the two schemes
    ###########################
    for k1, k2 in zip(net_d_1.parameters(), net_d_2.parameters()):
        print(torch.sum(torch.abs(k1.grad - k2.grad)))


if __name__ == '__main__':
    main()

r"""Output:
conv0.weight torch.Size([4, 3, 3, 3])
conv0.bias torch.Size([4])
bn0.weight torch.Size([4])
bn0.bias torch.Size([4])
conv1.weight torch.Size([4, 4, 3, 3])
conv1.bias torch.Size([4])
bn1.weight torch.Size([4])
bn1.bias torch.Size([4])
linear.weight torch.Size([1, 144])
linear.bias torch.Size([1])
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
"""