22import os
33import numpy as np
44import math
5+ import time
56
67import torchvision .transforms as transforms
78from torchvision .utils import save_image
# Command-line options. `parser` is created earlier in the file.
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# The help text used to repeat the --channels description; the value is the
# number of batches between saved image samples in the training loop.
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples, in batches")
parser.add_argument("--inference", action="store_true", default=False,
                    help="run the generator inference benchmark instead of training")
parser.add_argument("--precision", default="float32", help='Precision, "float32" or "bfloat16"')
parser.add_argument("--channels_last", type=int, default=1, help="use channels last format")
parser.add_argument("--num-iterations", default=100, type=int,
                    help="number of generator forward passes timed in inference mode")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

# is_available() already returns a bool.
cuda = torch.cuda.is_available()
# Tensor constructor used throughout the file so new tensors land on the GPU when present.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
3642
3743
3844def weights_init_normal (m ):
@@ -68,6 +74,8 @@ def __init__(self):
def forward(self, noise):
    """Turn a batch of noise vectors into images.

    The output of ``self.l1`` is viewed as a
    ``(batch, 128, init_size, init_size)`` feature map. When
    ``opt.channels_last`` is truthy the map is converted to the
    channels-last memory format before ``self.conv_blocks`` is applied.
    """
    projected = self.l1(noise)
    side = self.init_size
    feature_map = projected.view(projected.shape[0], 128, side, side)
    if opt.channels_last:
        feature_map = feature_map.to(memory_format=torch.channels_last)
    return self.conv_blocks(feature_map)
7381
@@ -94,116 +102,172 @@ def __init__(self):
94102
def forward(self, img):
    """Run an image batch through ``down``, ``fc`` and ``up`` and return the result.

    The ``down`` output is flattened per sample for ``self.fc``, whose output
    is viewed as ``(batch, 64, down_size, down_size)`` and passed to
    ``self.up``. ``opt.channels_last`` toggles the memory-format handling
    around the two ``view`` calls.
    """
    use_nhwc = opt.channels_last

    encoded = self.down(img)
    if use_nhwc:
        # Presumably needed so the flattening view() below sees a standard
        # contiguous layout -- confirm against the layers in ``down``.
        encoded = encoded.contiguous()

    flat = encoded.view(encoded.size(0), -1)
    latent = self.fc(flat)

    side = self.down_size
    decoder_input = latent.view(latent.size(0), 64, side, side)
    if use_nhwc:
        decoder_input = decoder_input.to(memory_format=torch.channels_last)

    return self.up(decoder_input)
100113
def main():
    """Build both networks and run either the inference benchmark or training.

    Reads the module-level ``opt`` and ``cuda`` settings:

    * ``opt.inference`` selects ``generate`` (truthy) or ``train`` (falsy).
    * ``opt.precision == "bfloat16"`` runs the selected routine under
      autocast with an explicit bfloat16 dtype; any other value runs it
      without autocast, as before.
    """
    # Local import keeps this block self-contained (stdlib only).
    from contextlib import nullcontext

    device = torch.device("cuda" if cuda else "cpu")

    # Initialize generator and discriminator on the selected device
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    if opt.precision == "bfloat16":
        # The dtype must be explicit: CUDA autocast defaults to float16, so
        # the previous bare ``torch.cuda.amp.autocast()`` did not run bfloat16.
        precision_ctx = torch.autocast(device_type=device.type, dtype=torch.bfloat16)
    else:
        precision_ctx = nullcontext()

    with precision_ctx:
        if opt.inference:
            print("----------------Generation---------------")
            generate(generator, device=device)
        else:
            print("-------------------Train-----------------")
            train(generator, discriminator)
147+
def generate(netG, device):
    """Time repeated generator forward passes and print throughput and latency.

    A fixed batch of ``10 ** 2`` latent vectors of size ``opt.latent_dim`` is
    fed to ``netG`` ``opt.num_iterations`` times under ``torch.no_grad()``.

    Args:
        netG: generator module to benchmark; switched to eval mode here.
        device: device (``torch.device`` or device string) the noise batch
            is placed on; expected to match the device ``netG`` lives on.

    Raises:
        ValueError: if ``opt.num_iterations`` is not positive (the latency
            figure would otherwise divide by zero after the run).
    """
    total_iters = opt.num_iterations
    if total_iters <= 0:
        raise ValueError("--num-iterations must be a positive integer")

    device = torch.device(device)
    num_samples = 10 ** 2
    noise = np.random.normal(0, 1, (num_samples, opt.latent_dim))
    # Place the noise on `device` regardless of the channels_last setting.
    fixed_noise = torch.as_tensor(noise, dtype=torch.float32, device=device)

    if opt.channels_last:
        try:
            netG = netG.to(memory_format=torch.channels_last)
            print("[INFO] Use NHWC model")
        except (AttributeError, RuntimeError, TypeError):
            # Best effort: keep the unconverted model, but no longer swallow
            # KeyboardInterrupt/SystemExit as the bare ``except`` did.
            print("[WARN] Input NHWC failed! Use normal model")
    netG.eval()

    on_cuda = device.type == "cuda"
    with torch.no_grad():
        if on_cuda:
            # CUDA kernels launch asynchronously; synchronise so the timer
            # brackets the actual compute rather than just the launches.
            torch.cuda.synchronize()
        tic = time.perf_counter()
        for _ in range(total_iters):
            netG(fixed_noise)
        if on_cuda:
            torch.cuda.synchronize()
        toc = time.perf_counter() - tic

    # Report the batch size that was actually run. The benchmark batch is
    # fixed at `num_samples`, not `opt.batch_size`, which previously skewed
    # the reported throughput.
    print("Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms" % (
        (total_iters * num_samples) / toc, num_samples, 1000 * toc / total_iters))
101169
102- # Initialize generator and discriminator
103- generator = Generator ()
104- discriminator = Discriminator ()
105-
106- if cuda :
107- generator .cuda ()
108- discriminator .cuda ()
109-
110- # Initialize weights
111- generator .apply (weights_init_normal )
112- discriminator .apply (weights_init_normal )
113-
114- # Configure data loader
115- os .makedirs ("../../data/mnist" , exist_ok = True )
116- dataloader = torch .utils .data .DataLoader (
117- datasets .MNIST (
118- "../../data/mnist" ,
119- train = True ,
120- download = True ,
121- transform = transforms .Compose (
122- [transforms .Resize (opt .img_size ), transforms .ToTensor (), transforms .Normalize ([0.5 ], [0.5 ])]
123- ),
124- ),
125- batch_size = opt .batch_size ,
126- shuffle = True ,
127- )
128-
129- # Optimizers
130- optimizer_G = torch .optim .Adam (generator .parameters (), lr = opt .lr , betas = (opt .b1 , opt .b2 ))
131- optimizer_D = torch .optim .Adam (discriminator .parameters (), lr = opt .lr , betas = (opt .b1 , opt .b2 ))
132-
133- Tensor = torch .cuda .FloatTensor if cuda else torch .FloatTensor
134170
135171# ----------
136172# Training
137173# ----------
138174
139- # BEGAN hyper parameters
140- gamma = 0.75
141- lambda_k = 0.001
142- k = 0.0
143-
144- for epoch in range (opt .n_epochs ):
145- for i , (imgs , _ ) in enumerate (dataloader ):
146-
147- # Configure input
148- real_imgs = Variable (imgs .type (Tensor ))
149-
150- # -----------------
151- # Train Generator
152- # -----------------
153-
154- optimizer_G .zero_grad ()
155-
156- # Sample noise as generator input
157- z = Variable (Tensor (np .random .normal (0 , 1 , (imgs .shape [0 ], opt .latent_dim ))))
158-
159- # Generate a batch of images
160- gen_imgs = generator (z )
161-
162- # Loss measures generator's ability to fool the discriminator
163- g_loss = torch .mean (torch .abs (discriminator (gen_imgs ) - gen_imgs ))
164-
165- g_loss .backward ()
166- optimizer_G .step ()
167-
168- # ---------------------
169- # Train Discriminator
170- # ---------------------
171-
172- optimizer_D .zero_grad ()
173-
174- # Measure discriminator's ability to classify real from generated samples
175- d_real = discriminator (real_imgs )
176- d_fake = discriminator (gen_imgs .detach ())
177-
178- d_loss_real = torch .mean (torch .abs (d_real - real_imgs ))
179- d_loss_fake = torch .mean (torch .abs (d_fake - gen_imgs .detach ()))
180- d_loss = d_loss_real - k * d_loss_fake
181-
182- d_loss .backward ()
183- optimizer_D .step ()
184-
185- # ----------------
186- # Update weights
187- # ----------------
188-
189- diff = torch .mean (gamma * d_loss_real - d_loss_fake )
190-
191- # Update weight term for fake samples
192- k = k + lambda_k * diff .item ()
193- k = min (max (k , 0 ), 1 ) # Constraint to interval [0, 1]
194-
195- # Update convergence metric
196- M = (d_loss_real + torch .abs (diff )).data [0 ]
197-
198- # --------------
199- # Log Progress
200- # --------------
201-
202- print (
203- "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
204- % (epoch , opt .n_epochs , i , len (dataloader ), d_loss .item (), g_loss .item (), M , k )
205- )
206-
207- batches_done = epoch * len (dataloader ) + i
208- if batches_done % opt .sample_interval == 0 :
209- save_image (gen_imgs .data [:25 ], "images/%d.png" % batches_done , nrow = 5 , normalize = True )
def train(netG, netD):
    """Train the BEGAN generator/discriminator pair on MNIST.

    Builds the MNIST data loader and one Adam optimizer per network from the
    module-level ``opt`` settings, then alternates a generator step and a
    discriminator step per batch for ``opt.n_epochs`` epochs. The balance
    term ``k`` is updated after every batch and clamped to ``[0, 1]``.
    Progress is printed per batch and a 5x5 grid of generated samples is
    written to ``images/`` every ``opt.sample_interval`` batches.

    Args:
        netG: generator; called on a ``(batch, opt.latent_dim)`` noise tensor.
        netD: discriminator; its output is compared element-wise with its
            input, so it must return a tensor shaped like the input images.
    """
    # BEGAN hyper parameters
    gamma = 0.75
    lambda_k = 0.001
    k = 0.0

    # Configure data loader
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )
    num_batches = len(dataloader)

    # save_image does not create its target directory; a no-op if it exists.
    os.makedirs("images", exist_ok=True)

    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            if opt.channels_last:
                try:
                    imgs = imgs.to(memory_format=torch.channels_last)
                    print("[INFO] Use NHWC input")
                except (AttributeError, RuntimeError, TypeError):
                    # Best effort: fall back to the default layout without
                    # swallowing KeyboardInterrupt/SystemExit.
                    print("[WARN] Input NHWC failed! Use normal input")

            # Configure input
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))

            # Generate a batch of images
            gen_imgs = netG(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Detach once so the discriminator step cannot backpropagate
            # into the generator.
            fake_imgs = gen_imgs.detach()

            # Measure discriminator's ability to classify real from generated samples
            d_real = netD(real_imgs)
            d_fake = netD(fake_imgs)

            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
            d_loss_fake = torch.mean(torch.abs(d_fake - fake_imgs))
            d_loss = d_loss_real - k * d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # ----------------
            # Update weights
            # ----------------

            diff = torch.mean(gamma * d_loss_real - d_loss_fake)

            # Update weight term for fake samples
            k = k + lambda_k * diff.item()
            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]

            # Update convergence metric
            M = (d_loss_real + torch.abs(diff)).item()

            # --------------
            # Log Progress
            # --------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
                % (epoch, opt.n_epochs, i, num_batches, d_loss.item(), g_loss.item(), M, k)
            )

            batches_done = epoch * num_batches + i
            if batches_done % opt.sample_interval == 0:
                save_image(fake_imgs[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
# Script entry point: only run when executed directly, not on import.
if __name__ == '__main__':
    main()
0 commit comments