22import os
33import numpy as np
44import math
5+ import time
56
67import torchvision .transforms as transforms
78from torchvision .utils import save_image
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# The interval is compared against the running batch counter in the training loop.
parser.add_argument(
    "--sample_interval", type=int, default=400, help="number of batches between saved image samples"
)
parser.add_argument(
    "--inference", action="store_true", default=False, help="run the generator benchmark instead of training"
)
parser.add_argument("--precision", default="float32", help='Precision, "float32" or "bfloat16"')
parser.add_argument("--channels_last", type=int, default=1, help="use channels last format")
# argparse exposes this option as ``opt.num_iterations``.
parser.add_argument(
    "--num-iterations", default=100, type=int, help="number of timed generator forward passes in inference mode"
)
opt = parser.parse_args()
print(opt)

# (C, H, W) shape of a single image, derived from the CLI options.
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = torch.cuda.is_available()
# Legacy tensor constructor matching the selected device; used to build noise and cast inputs.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
3642
3743
3844def weights_init_normal (m ):
@@ -68,6 +74,8 @@ def __init__(self):
6874 def forward (self , noise ):
6975 out = self .l1 (noise )
7076 out = out .view (out .shape [0 ], 128 , self .init_size , self .init_size )
77+ if opt .channels_last :
78+ out = out .to (memory_format = torch .channels_last )
7179 img = self .conv_blocks (out )
7280 return img
7381
@@ -94,116 +102,167 @@ def __init__(self):
94102
95103 def forward (self , img ):
96104 out = self .down (img )
105+ if opt .channels_last :
106+ out = out .contiguous ()
97107 out = self .fc (out .view (out .size (0 ), - 1 ))
98- out = self .up (out .view (out .size (0 ), 64 , self .down_size , self .down_size ))
108+ out = out .view (out .size (0 ), 64 , self .down_size , self .down_size )
109+ if opt .channels_last :
110+ out = out .to (memory_format = torch .channels_last )
111+ out = self .up (out )
99112 return out
100113
def main():
    """Build the generator/discriminator and run inference or training.

    Both networks are created, moved to the CUDA device when the module-level
    ``cuda`` flag is set (CPU otherwise) and initialised with
    ``weights_init_normal``.  With ``--inference`` the generator benchmark
    (``generate``) is run, under bfloat16 autocast when
    ``--precision bfloat16`` was requested; otherwise ``train`` is run.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")

    # Initialize generator and discriminator
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    if not opt.inference:
        print("-------------------Train-----------------")
        train(generator, discriminator)
        return

    print("----------------Generation---------------")
    if opt.precision == "bfloat16":
        # Request bfloat16 explicitly: the CUDA autocast default dtype is
        # float16, which would not match the precision the user asked for.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            generate(generator, device=device)
    else:
        generate(generator, device=device)
141+
142+
def generate(netG, device):
    """Benchmark generator inference and print throughput/latency.

    A fixed batch of 100 latent vectors (``opt.latent_dim`` wide) is drawn
    from a standard normal distribution and pushed through ``netG``
    ``opt.num_iterations`` times under ``torch.no_grad()``.

    Args:
        netG: generator module to benchmark; switched to eval mode and, when
            ``opt.channels_last`` is truthy, to channels-last memory format
            (best effort).
        device: device (``torch.device`` or device string) the noise batch is
            created on.

    The reported batch size and throughput are based on the batch that is
    actually run, not on ``opt.batch_size``.
    """
    total_iters = opt.num_iterations
    if total_iters <= 0:
        # Nothing to time; avoids dividing by zero in the report below.
        print("[WARN] --num-iterations must be positive; nothing to benchmark")
        return

    device = torch.device(device)
    noise = np.random.normal(0, 1, (10 ** 2, opt.latent_dim))
    # Always place the noise on the requested device, whatever the layout option.
    fixed_noise = torch.as_tensor(noise, dtype=torch.float32, device=device)
    batch_size = fixed_noise.shape[0]

    if opt.channels_last:
        try:
            netG = netG.to(memory_format=torch.channels_last)
            print("[INFO] Use NHWC model")
        except Exception:
            # Best effort: keep the default layout, but let KeyboardInterrupt
            # and SystemExit propagate instead of swallowing them.
            print("[WARN] Input NHWC failed! Use normal model")
    netG.eval()

    on_cuda = device.type == "cuda"
    with torch.no_grad():
        if on_cuda:
            # CUDA kernels launch asynchronously; drain pending work so the
            # timer only covers the benchmarked forward passes.
            torch.cuda.synchronize()
        tic = time.perf_counter()
        for _ in range(total_iters):
            netG(fixed_noise)
        if on_cuda:
            torch.cuda.synchronize()
        toc = time.perf_counter() - tic

    print(
        "Throughput: %.2f image/sec, batchsize: %d, latency = %.2f ms"
        % ((total_iters * batch_size) / toc, batch_size, 1000 * toc / total_iters)
    )
101164
102- # Initialize generator and discriminator
103- generator = Generator ()
104- discriminator = Discriminator ()
105-
106- if cuda :
107- generator .cuda ()
108- discriminator .cuda ()
109-
110- # Initialize weights
111- generator .apply (weights_init_normal )
112- discriminator .apply (weights_init_normal )
113-
114- # Configure data loader
115- os .makedirs ("../../data/mnist" , exist_ok = True )
116- dataloader = torch .utils .data .DataLoader (
117- datasets .MNIST (
118- "../../data/mnist" ,
119- train = True ,
120- download = True ,
121- transform = transforms .Compose (
122- [transforms .Resize (opt .img_size ), transforms .ToTensor (), transforms .Normalize ([0.5 ], [0.5 ])]
123- ),
124- ),
125- batch_size = opt .batch_size ,
126- shuffle = True ,
127- )
128-
129- # Optimizers
130- optimizer_G = torch .optim .Adam (generator .parameters (), lr = opt .lr , betas = (opt .b1 , opt .b2 ))
131- optimizer_D = torch .optim .Adam (discriminator .parameters (), lr = opt .lr , betas = (opt .b1 , opt .b2 ))
132-
133- Tensor = torch .cuda .FloatTensor if cuda else torch .FloatTensor
134165
135166# ----------
136167# Training
137168# ----------
138169
139- # BEGAN hyper parameters
140- gamma = 0.75
141- lambda_k = 0.001
142- k = 0.0
143-
144- for epoch in range (opt .n_epochs ):
145- for i , (imgs , _ ) in enumerate (dataloader ):
146-
147- # Configure input
148- real_imgs = Variable (imgs .type (Tensor ))
149-
150- # -----------------
151- # Train Generator
152- # -----------------
153-
154- optimizer_G .zero_grad ()
155-
156- # Sample noise as generator input
157- z = Variable (Tensor (np .random .normal (0 , 1 , (imgs .shape [0 ], opt .latent_dim ))))
158-
159- # Generate a batch of images
160- gen_imgs = generator (z )
161-
162- # Loss measures generator's ability to fool the discriminator
163- g_loss = torch .mean (torch .abs (discriminator (gen_imgs ) - gen_imgs ))
164-
165- g_loss .backward ()
166- optimizer_G .step ()
167-
168- # ---------------------
169- # Train Discriminator
170- # ---------------------
171-
172- optimizer_D .zero_grad ()
173-
174- # Measure discriminator's ability to classify real from generated samples
175- d_real = discriminator (real_imgs )
176- d_fake = discriminator (gen_imgs .detach ())
177-
178- d_loss_real = torch .mean (torch .abs (d_real - real_imgs ))
179- d_loss_fake = torch .mean (torch .abs (d_fake - gen_imgs .detach ()))
180- d_loss = d_loss_real - k * d_loss_fake
181-
182- d_loss .backward ()
183- optimizer_D .step ()
184-
185- # ----------------
186- # Update weights
187- # ----------------
188-
189- diff = torch .mean (gamma * d_loss_real - d_loss_fake )
190-
191- # Update weight term for fake samples
192- k = k + lambda_k * diff .item ()
193- k = min (max (k , 0 ), 1 ) # Constraint to interval [0, 1]
194-
195- # Update convergence metric
196- M = (d_loss_real + torch .abs (diff )).data [0 ]
197-
198- # --------------
199- # Log Progress
200- # --------------
201-
202- print (
203- "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
204- % (epoch , opt .n_epochs , i , len (dataloader ), d_loss .item (), g_loss .item (), M , k )
205- )
206-
207- batches_done = epoch * len (dataloader ) + i
208- if batches_done % opt .sample_interval == 0 :
209- save_image (gen_imgs .data [:25 ], "images/%d.png" % batches_done , nrow = 5 , normalize = True )
def train(netG, netD):
    """Train the generator/discriminator pair on MNIST with the BEGAN objective.

    Args:
        netG: generator; maps latent noise of width ``opt.latent_dim`` to images.
        netD: auto-encoding discriminator; its output is compared to its
            input with an L1 reconstruction loss.

    MNIST is downloaded to ``../../data/mnist`` if needed.  Progress is
    printed every batch and a 5x5 grid of generated samples is written to
    ``images/<batches_done>.png`` every ``opt.sample_interval`` batches.
    """
    # BEGAN hyper parameters
    gamma = 0.75
    lambda_k = 0.001
    k = 0.0

    # Configure data loader
    os.makedirs("../../data/mnist", exist_ok=True)
    # save_image does not create missing directories, so make sure the
    # sample output folder exists before the first snapshot is written.
    os.makedirs("images", exist_ok=True)
    transform = transforms.Compose(
        [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    )
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transform,
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )
    num_batches = len(dataloader)

    # Optimizers
    optimizer_G = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            if opt.channels_last:
                try:
                    imgs = imgs.to(memory_format=torch.channels_last)
                    print("[INFO] Use NHWC input")
                except Exception:
                    # Best effort: fall back to the default layout, but do not
                    # swallow KeyboardInterrupt/SystemExit like a bare except.
                    print("[WARN] Input NHWC failed! Use normal input")

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

            # Generate a batch of images
            gen_imgs = netG(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples.
            # The generated batch is detached so this step does not update the generator.
            fake_imgs = gen_imgs.detach()
            d_real = netD(real_imgs)
            d_fake = netD(fake_imgs)

            d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
            d_loss_fake = torch.mean(torch.abs(d_fake - fake_imgs))
            d_loss = d_loss_real - k * d_loss_fake

            d_loss.backward()
            optimizer_D.step()

            # ----------------
            # Update weights
            # ----------------

            diff = torch.mean(gamma * d_loss_real - d_loss_fake)

            # Update weight term for fake samples
            k = k + lambda_k * diff.item()
            k = min(max(k, 0), 1)  # Constraint to interval [0, 1]

            # Update convergence metric
            M = (d_loss_real + torch.abs(diff)).item()

            # --------------
            # Log Progress
            # --------------

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] -- M: %f, k: %f"
                % (epoch, opt.n_epochs, i, num_batches, d_loss.item(), g_loss.item(), M, k)
            )

            batches_done = epoch * num_batches + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
# Script entry point: only dispatch to main() when run directly, not on import.
if __name__ == "__main__":
    main()
0 commit comments