Skip to content

Commit f307843

Browse files
committed
add inference, amp and channel_last supports for began
1 parent 36d3c77 commit f307843

File tree

2 files changed

+310
-186
lines changed

2 files changed

+310
-186
lines changed

implementations/began/began.py

Lines changed: 163 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import numpy as np
44
import math
5+
import time
56

67
import torchvision.transforms as transforms
78
from torchvision.utils import save_image
# CLI options (continuation of the ArgumentParser created above).
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# BUG FIX: help text was a copy-paste of --channels ("number of image channels").
parser.add_argument("--sample_interval", type=int, default=400, help="batches between image samples")
parser.add_argument('--inference', action='store_true', default=False)
parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
parser.add_argument('--num-iterations', default=100, type=int)
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

# is_available() already returns a bool; no conditional expression needed.
cuda = torch.cuda.is_available()
# Default float tensor factory, device-selected once at import time.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
3642

3743

3844
def weights_init_normal(m):
@@ -68,6 +74,8 @@ def __init__(self):
6874
def forward(self, noise):
    """Map a latent noise batch to generated images.

    The fully connected stem produces a flat feature vector that is
    reshaped to (N, 128, init_size, init_size) before the conv stack.
    """
    feat = self.l1(noise)
    feat = feat.view(feat.shape[0], 128, self.init_size, self.init_size)
    if opt.channels_last:
        # Keep activations in NHWC when the channels-last flag is set.
        feat = feat.to(memory_format=torch.channels_last)
    return self.conv_blocks(feat)
7381

@@ -94,116 +102,167 @@ def __init__(self):
94102

95103
def forward(self, img):
    """Auto-encode the image: downsample, bottleneck through fc, upsample."""
    hidden = self.down(img)
    if opt.channels_last:
        # NOTE(review): presumably undoes channels-last strides so the
        # flatten below sees a dense layout — confirm against self.down.
        hidden = hidden.contiguous()
    hidden = self.fc(hidden.view(hidden.size(0), -1))
    hidden = hidden.view(hidden.size(0), 64, self.down_size, self.down_size)
    if opt.channels_last:
        # Restore NHWC before the decoder convolutions.
        hidden = hidden.to(memory_format=torch.channels_last)
    return self.up(hidden)
100113

114+
def main():
    """Build the BEGAN models and dispatch to inference or training.

    Routing is controlled by the CLI flags parsed at import time:
    --inference selects generation/benchmarking, otherwise training;
    --precision bfloat16 wraps generation in an autocast context.
    """
    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    if cuda:
        generator.cuda()
        discriminator.cuda()
    else:
        generator.cpu()
        discriminator.cpu()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    device = torch.device('cuda') if cuda else torch.device('cpu')
    if opt.inference:
        print("----------------Generation---------------")
        if opt.precision == "bfloat16":
            # BUG FIX: torch.cuda.amp.autocast defaults to float16, so the
            # bfloat16 flag silently ran fp16 on CUDA. Request bfloat16
            # explicitly; torch.cpu.amp.autocast accepts the same argument.
            cm = torch.cuda.amp.autocast if cuda else torch.cpu.amp.autocast
            with cm(dtype=torch.bfloat16):
                generate(generator, device=device)
        else:
            generate(generator, device=device)
    else:
        print("-------------------Train-----------------")
        train(generator, discriminator)
141+
142+
143+
def generate(netG, device):
    """Benchmark generator inference.

    Runs opt.num_iterations forward passes on one fixed noise batch of
    100 latent vectors and prints throughput and per-batch latency.
    """
    fixed_noise = Variable(Tensor(np.random.normal(0, 1, (10 ** 2, opt.latent_dim))))
    if opt.channels_last:
        try:
            netG = netG.to(memory_format=torch.channels_last)
            print("[INFO] Use NHWC model")
        # BUG FIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        except Exception:
            print("[WARN] Input NHWC failed! Use normal model")
    # BUG FIX: the device move and eval() were only in the non-NHWC branch,
    # so the default --channels_last 1 path benchmarked a model left in
    # training mode. Both must run in every path.
    fixed_noise = fixed_noise.to(device=device)
    netG.eval()

    total_iters = opt.num_iterations
    with torch.no_grad():
        tic = time.time()
        for _ in range(total_iters):
            fake = netG(fixed_noise)
        toc = time.time() - tic
    # BUG FIX: throughput/latency were scaled by opt.batch_size, but the
    # fixed-noise batch is 10**2 images; report the actual batch size.
    bs = fixed_noise.size(0)
    print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"
          % ((total_iters * bs) / toc, bs, 1000 * toc / total_iters))
101164

102-
# Initialize generator and discriminator
103-
generator = Generator()
104-
discriminator = Discriminator()
105-
106-
if cuda:
107-
generator.cuda()
108-
discriminator.cuda()
109-
110-
# Initialize weights
111-
generator.apply(weights_init_normal)
112-
discriminator.apply(weights_init_normal)
113-
114-
# Configure data loader
115-
os.makedirs("../../data/mnist", exist_ok=True)
116-
dataloader = torch.utils.data.DataLoader(
117-
datasets.MNIST(
118-
"../../data/mnist",
119-
train=True,
120-
download=True,
121-
transform=transforms.Compose(
122-
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
123-
),
124-
),
125-
batch_size=opt.batch_size,
126-
shuffle=True,
127-
)
128-
129-
# Optimizers
130-
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
131-
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
132-
133-
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
134165

135166
# ----------
136167
# Training
137168
# ----------
138169

139-
# BEGAN hyper parameters
140-
gamma = 0.75
141-
lambda_k = 0.001
142-
k = 0.0
143-
144-
for epoch in range(opt.n_epochs):
145-
for i, (imgs, _) in enumerate(dataloader):
146-
147-
# Configure input
148-
real_imgs = Variable(imgs.type(Tensor))
149-
150-
# -----------------
151-
# Train Generator
152-
# -----------------
153-
154-
optimizer_G.zero_grad()
155-
156-
# Sample noise as generator input
157-
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
158-
159-
# Generate a batch of images
160-
gen_imgs = generator(z)
161-
162-
# Loss measures generator's ability to fool the discriminator
163-
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))
164-
165-
g_loss.backward()
166-
optimizer_G.step()
167-
168-
# ---------------------
169-
# Train Discriminator
170-
# ---------------------
171-
172-
optimizer_D.zero_grad()
173-
174-
# Measure discriminator's ability to classify real from generated samples
175-
d_real = discriminator(real_imgs)
176-
d_fake = discriminator(gen_imgs.detach())
177-
178-
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
179-
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
180-
d_loss = d_loss_real - k * d_loss_fake
181-
182-
d_loss.backward()
183-
optimizer_D.step()
184-
185-
# ----------------
186-
# Update weights
187-
# ----------------
188-
189-
diff = torch.mean(gamma * d_loss_real - d_loss_fake)
190-
191-
# Update weight term for fake samples
192-
k = k + lambda_k * diff.item()
193-
k = min(max(k, 0), 1) # Constraint to interval [0, 1]
194-
195-
# Update convergence metric
196-
M = (d_loss_real + torch.abs(diff)).data[0]
197-
198-
# --------------
199-
# Log Progress
200-
# --------------
201-
202-
print(
203-
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
204-
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
205-
)
206-
207-
batches_done = epoch * len(dataloader) + i
208-
if batches_done % opt.sample_interval == 0:
209-
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
170+
def train(netG, netD):
    """Train BEGAN on MNIST.

    netG : generator module (updated to fool netD's reconstruction).
    netD : discriminator/auto-encoder module.
    Side effects: downloads MNIST under ../../data/mnist, prints per-batch
    progress, and writes sample grids to images/<batches_done>.png.
    """
    # BEGAN hyper parameters
    gamma = 0.75       # diversity ratio target
    lambda_k = 0.001   # learning rate for the control variable k
    k = 0.0            # balance term, kept in [0, 1]

    # Configure data loader
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )
    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # BUG FIX: the NHWC status message used to print for every batch of
    # every epoch; log it only on the first batch.
    nhwc_logged = False
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            if opt.channels_last:
                try:
                    # The conversion itself must still run per batch —
                    # every loader batch is a fresh contiguous tensor.
                    imgs = imgs.to(memory_format=torch.channels_last)
                    if not nhwc_logged:
                        print("[INFO] Use NHWC input")
                # BUG FIX: was a bare `except:`.
                except Exception:
                    if not nhwc_logged:
                        print("[WARN] Input NHWC failed! Use normal input")
                nhwc_logged = True
            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            gen_imgs = netG(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            d_real = netD(real_imgs)
            d_fake = netD(gen_imgs.detach())

            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
            d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
            d_loss = d_loss_real - k * d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # ----------------
            # Update weights
            # ----------------

            diff = torch.mean(gamma * d_loss_real - d_loss_fake)

            # Update weight term for fake samples
            k = k + lambda_k * diff.item()
            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]

            # Update convergence metric
            M = (d_loss_real + torch.abs(diff)).data.item()

            # --------------
            # Log Progress
            # --------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
267+
# Script entry point: argparse has already run at import time; main()
# dispatches to train() or generate() based on the parsed flags.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)