Skip to content

Commit 6cb088d

Browse files
committed
add inference, AMP, and channels_last support for BEGAN
1 parent 36d3c77 commit 6cb088d

File tree

2 files changed

+313
-186
lines changed

2 files changed

+313
-186
lines changed

implementations/began/began.py

Lines changed: 168 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import numpy as np
44
import math
5+
import time
56

67
import torchvision.transforms as transforms
78
from torchvision.utils import save_image
@@ -27,12 +28,17 @@
2728
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# Fixed help text: it was a copy-paste of the --channels help ("number of image channels").
parser.add_argument("--sample_interval", type=int, default=400, help="interval (in batches) between saved image samples")
parser.add_argument('--inference', action='store_true', default=False,
                    help='run the generator inference benchmark instead of training')
parser.add_argument('--precision', default='float32', help='Precision, "float32" or "bfloat16"')
parser.add_argument('--channels_last', type=int, default=1, help='use channels last format')
parser.add_argument('--num-iterations', default=100, type=int,
                    help='number of iterations for the inference benchmark')
3035
opt = parser.parse_args()
3136
print(opt)
3237

3338
# Shape of one sample as (channels, height, width).
img_shape = (opt.channels, opt.img_size, opt.img_size)

# `torch.cuda.is_available()` already returns a bool; no ternary needed.
cuda = torch.cuda.is_available()

# Tensor factory matching the selected device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
3642

3743

3844
def weights_init_normal(m):
@@ -68,6 +74,8 @@ def __init__(self):
6874
def forward(self, noise):
    """Map a batch of latent vectors to a batch of generated images.

    noise: 2-D tensor (batch, latent_dim) — assumed; confirm against caller.
    Returns the output of the convolutional decoder.
    """
    hidden = self.l1(noise)
    hidden = hidden.view(hidden.shape[0], 128, self.init_size, self.init_size)
    if opt.channels_last:
        # Switch the 4-D activation to NHWC before the conv stack.
        hidden = hidden.to(memory_format=torch.channels_last)
    return self.conv_blocks(hidden)
7381

@@ -94,116 +102,172 @@ def __init__(self):
94102

95103
def forward(self, img):
    """Auto-encode `img`: downsample, bottleneck through `fc`, then upsample.

    Returns the reconstruction produced by `self.up` (BEGAN-style
    auto-encoder discriminator — presumed; confirm with class definition).
    """
    features = self.down(img)
    if opt.channels_last:
        # The flatten below needs a contiguous (NCHW-strided) tensor.
        features = features.contiguous()
    flat = features.view(features.size(0), -1)
    code = self.fc(flat)
    code = code.view(code.size(0), 64, self.down_size, self.down_size)
    if opt.channels_last:
        # Restore NHWC layout for the upsampling convolutions.
        code = code.to(memory_format=torch.channels_last)
    return self.up(code)
100113

114+
def main():
    """Build the GAN, then dispatch to inference benchmarking or training.

    Honors --inference (generate vs. train) and --precision (bfloat16
    autocast vs. default float32).
    """
    import contextlib

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    device = torch.device('cuda') if cuda else torch.device('cpu')
    # .to(device) covers both the CUDA and CPU cases (modules start on CPU,
    # so the original `else: .cpu()` branch was a no-op).
    generator.to(device)
    discriminator.to(device)

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    if opt.precision == "bfloat16":
        # BUGFIX: torch.cuda.amp.autocast() defaults to float16, so the old
        # code silently ran fp16 on CUDA despite --precision bfloat16.
        # torch.autocast with an explicit dtype honors the request on both
        # backends.
        autocast_cm = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
    else:
        autocast_cm = contextlib.nullcontext()

    if opt.inference:
        print("----------------Generation---------------")
        with autocast_cm:
            generate(generator, device=device)
    else:
        print("-------------------Train-----------------")
        with autocast_cm:
            train(generator, discriminator)
146+
147+
148+
def generate(netG, device):
    """Benchmark generator inference and print throughput/latency.

    Runs `opt.num_iterations` forward passes on a fixed batch of 100 latent
    vectors under torch.no_grad().
    """
    batch = 10 ** 2  # the benchmark always generates 100 images per pass
    fixed_noise = Tensor(np.random.normal(0, 1, (batch, opt.latent_dim)))
    # BUGFIX: the device move used to live in the `else` branch only, so with
    # --channels_last the noise was never explicitly placed on `device`.
    fixed_noise = fixed_noise.to(device=device)

    if opt.channels_last:
        # channels_last only affects 4-D weights/activations; the 2-D noise
        # input needs no conversion.
        try:
            netG = netG.to(memory_format=torch.channels_last)
            print("[INFO] Use NHWC model")
        except (RuntimeError, TypeError):  # narrow: bare except hid real errors
            print("[WARN] Input NHWC failed! Use normal model")

    netG.eval()

    total_iters = opt.num_iterations
    with torch.no_grad():
        tic = time.time()
        for _ in range(total_iters):
            netG(fixed_noise)
        toc = time.time() - tic
    # BUGFIX: report the batch size actually generated (100), not
    # opt.batch_size, which is unrelated to this benchmark.
    print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"
          % ((total_iters * batch) / toc, batch, 1000 * toc / total_iters))
101169

102-
# Initialize generator and discriminator
103-
generator = Generator()
104-
discriminator = Discriminator()
105-
106-
if cuda:
107-
generator.cuda()
108-
discriminator.cuda()
109-
110-
# Initialize weights
111-
generator.apply(weights_init_normal)
112-
discriminator.apply(weights_init_normal)
113-
114-
# Configure data loader
115-
os.makedirs("../../data/mnist", exist_ok=True)
116-
dataloader = torch.utils.data.DataLoader(
117-
datasets.MNIST(
118-
"../../data/mnist",
119-
train=True,
120-
download=True,
121-
transform=transforms.Compose(
122-
[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
123-
),
124-
),
125-
batch_size=opt.batch_size,
126-
shuffle=True,
127-
)
128-
129-
# Optimizers
130-
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
131-
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
132-
133-
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
134170

135171
# ----------
136172
# Training
137173
# ----------
138174

139-
# BEGAN hyper parameters
140-
gamma = 0.75
141-
lambda_k = 0.001
142-
k = 0.0
143-
144-
for epoch in range(opt.n_epochs):
145-
for i, (imgs, _) in enumerate(dataloader):
146-
147-
# Configure input
148-
real_imgs = Variable(imgs.type(Tensor))
149-
150-
# -----------------
151-
# Train Generator
152-
# -----------------
153-
154-
optimizer_G.zero_grad()
155-
156-
# Sample noise as generator input
157-
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
158-
159-
# Generate a batch of images
160-
gen_imgs = generator(z)
161-
162-
# Loss measures generator's ability to fool the discriminator
163-
g_loss = torch.mean(torch.abs(discriminator(gen_imgs) - gen_imgs))
164-
165-
g_loss.backward()
166-
optimizer_G.step()
167-
168-
# ---------------------
169-
# Train Discriminator
170-
# ---------------------
171-
172-
optimizer_D.zero_grad()
173-
174-
# Measure discriminator's ability to classify real from generated samples
175-
d_real = discriminator(real_imgs)
176-
d_fake = discriminator(gen_imgs.detach())
177-
178-
d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
179-
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
180-
d_loss = d_loss_real - k * d_loss_fake
181-
182-
d_loss.backward()
183-
optimizer_D.step()
184-
185-
# ----------------
186-
# Update weights
187-
# ----------------
188-
189-
diff = torch.mean(gamma * d_loss_real - d_loss_fake)
190-
191-
# Update weight term for fake samples
192-
k = k + lambda_k * diff.item()
193-
k = min(max(k, 0), 1) # Constraint to interval [0, 1]
194-
195-
# Update convergence metric
196-
M = (d_loss_real + torch.abs(diff)).data[0]
197-
198-
# --------------
199-
# Log Progress
200-
# --------------
201-
202-
print(
203-
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
204-
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
205-
)
206-
207-
batches_done = epoch * len(dataloader) + i
208-
if batches_done % opt.sample_interval == 0:
209-
save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
175+
def train(netG, netD):
    """Train the BEGAN generator/discriminator pair on MNIST.

    netG, netD: generator and auto-encoder discriminator modules, already on
    the right device and weight-initialized by the caller.
    """
    # BEGAN hyper parameters
    gamma = 0.75      # diversity ratio between real and fake reconstruction losses
    lambda_k = 0.001  # proportional gain for the k control term
    k = 0.0           # balancing term, constrained to [0, 1] below

    # Configure data loader
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )
    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    nhwc_logged = False  # BUGFIX: the NHWC status used to be printed every batch
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            if opt.channels_last:
                try:
                    imgs = imgs.to(memory_format=torch.channels_last)
                    if not nhwc_logged:
                        print("[INFO] Use NHWC input")
                except (RuntimeError, TypeError):  # narrow: bare except hid real errors
                    if not nhwc_logged:
                        print("[WARN] Input NHWC failed! Use normal input")
                nhwc_logged = True

            # Configure input (Variable is a deprecated no-op wrapper; dropped)
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

            # Generate a batch of images
            gen_imgs = netG(z)

            # Loss measures generator's ability to fool the discriminator
            # (BEGAN: L1 reconstruction error of fakes under the auto-encoder D)
            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            d_real = netD(real_imgs)
            d_fake = netD(gen_imgs.detach())

            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
            d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
            d_loss = d_loss_real - k * d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # ----------------
            # Update weights
            # ----------------
            diff = torch.mean(gamma * d_loss_real - d_loss_fake)

            # Update weight term for fake samples
            k = k + lambda_k * diff.item()
            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]

            # Update convergence metric (.item() — `.data` access is deprecated)
            M = (d_loss_real + torch.abs(diff)).item()

            # --------------
            #  Log Progress
            # --------------
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), M, k)
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
272+
# Script entry point: run only when executed directly, not on import.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)