 import os
 from dataclasses import dataclass
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import torch
@@ -159,6 +160,112 @@ def get_data_loader(
     assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
 
 
+def test_overfit_accelerate_mnist_simple_gan():
+    @dataclass
+    class GANModelConfig(TrainerConfig):
+        epochs: int = 1
+        print_step: int = 2
+        training_seed: int = 666
+
+    class GANModel(TrainerModel):
+        def __init__(self):
+            super().__init__()
+            data_shape = (1, 28, 28)
+            self.generator = Generator(latent_dim=100, img_shape=data_shape)
+            self.discriminator = Discriminator(img_shape=data_shape)
+
+        def forward(self, x):
+            ...
+
+        def train_step(self, batch, criterion, optimizer_idx):
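+            # The trainer calls train_step once per optimizer; optimizer_idx
+            # follows the order returned by get_optimizer() below
+            # (0 = discriminator, 1 = generator).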
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            # train discriminator
+            if optimizer_idx == 0:
+                imgs_gen = self.generator(z)
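+                # detach() the generated images so this update only
+                # trains the discriminator, not the generator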
+                logits = self.discriminator(imgs_gen.detach())
+                fake = torch.zeros(imgs.size(0), 1)
+                fake = fake.type_as(imgs)
+                loss_fake = criterion(logits, fake)
+
+                valid = torch.ones(imgs.size(0), 1)
+                valid = valid.type_as(imgs)
+                logits = self.discriminator(imgs)
+                loss_real = criterion(logits, valid)
+                loss = (loss_real + loss_fake) / 2
+                return {"model_outputs": logits}, {"loss": loss}
+
+            # train generator
+            if optimizer_idx == 1:
+                imgs_gen = self.generator(z)
+
+                valid = torch.ones(imgs.size(0), 1)
+                valid = valid.type_as(imgs)
+
+                logits = self.discriminator(imgs_gen)
+                loss_gen = criterion(logits, valid)
+                return {"model_outputs": logits}, {"loss": loss_gen}
+
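+        # evaluation reuses the training step, just without gradients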
+        @torch.no_grad()
+        def eval_step(self, batch, criterion, optimizer_idx):
+            return self.train_step(batch, criterion, optimizer_idx)
+
+        def get_optimizer(self):
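+            # one optimizer per sub-network; the generator gets a 10x
+            # larger learning rate than the discriminator here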
+            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
+            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
+            return [discriminator_optimizer, generator_optimizer]
+
+        def get_criterion(self):
+            return nn.BCELoss()
+
+        def get_data_loader(
+            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
+        ):  # pylint: disable=unused-argument
+            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
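+            # keep only 64 samples so the models can overfit quickly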
+            dataset.data = dataset.data[:64]
+            dataset.targets = dataset.targets[:64]
+            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=False)
+            return dataloader
+
+    config = GANModelConfig()
+    config.batch_size = 64
+    config.grad_clip = None
+    config.training_seed = 333
+
+    model = GANModel()
+    trainer = Trainer(
+        TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
+    )
+
+    trainer.eval_epoch()
+    loss_d1 = trainer.keep_avg_eval["avg_loss_0"]
+    loss_g1 = trainer.keep_avg_eval["avg_loss_1"]
+
+    trainer.config.epochs = 5
+    trainer.fit()
+    loss_d2 = trainer.keep_avg_train["avg_loss_0"]
+    loss_g2 = trainer.keep_avg_train["avg_loss_1"]
+
+    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
+    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
+    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
+    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
+
+
162 | 260 | def test_overfit_manual_optimize_mnist_simple_gan(): |
163 | 261 | @dataclass |
164 | 262 | class GANModelConfig(TrainerConfig): |
@@ -390,7 +488,139 @@ def get_data_loader(
     assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
 
 
+def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan():
+    @dataclass
+    class GANModelConfig(TrainerConfig):
+        epochs: int = 1
+        print_step: int = 2
+        training_seed: int = 666
+
+    class GANModel(TrainerModel):
+        def __init__(self):
+            super().__init__()
+            data_shape = (1, 28, 28)
+            self.generator = Generator(latent_dim=100, img_shape=data_shape)
+            self.discriminator = Discriminator(img_shape=data_shape)
+
+        def train_step(self):
+            ...
+
+        def forward(self, x):
+            ...
+
+        def optimize(self, batch, trainer):
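+            # manual optimization: instead of returning a loss from
+            # train_step, the model runs backward and steps its own
+            # optimizers inside optimize()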
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            # train discriminator
+            imgs_gen = self.generator(z)
+            logits = self.discriminator(imgs_gen.detach())
+            fake = torch.zeros(imgs.size(0), 1)
+            fake = fake.type_as(imgs)
+            loss_fake = trainer.criterion(logits, fake)
+
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+            logits = self.discriminator(imgs)
+            loss_real = trainer.criterion(logits, valid)
+            loss_disc = (loss_real + loss_fake) / 2
+
+            # step discriminator
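+            # scaled_backward() is expected to route backward through
+            # Accelerate (and the AMP grad scaler when enabled)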
+            self.scaled_backward(loss_disc, trainer, trainer.optimizer[0])
+
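+            # with gradient accumulation, only step and zero the
+            # optimizer every grad_accum_steps batches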
+            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+                trainer.optimizer[0].step()
+                trainer.optimizer[0].zero_grad()
+
+            # train generator
+            imgs_gen = self.generator(z)
+
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+
+            logits = self.discriminator(imgs_gen)
+            loss_gen = trainer.criterion(logits, valid)
+
+            # step generator
+            self.scaled_backward(loss_gen, trainer, trainer.optimizer[1])
+            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+                trainer.optimizer[1].step()
+                trainer.optimizer[1].zero_grad()
+            return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}
+
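+        # eval scores only the generator against the "valid" target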
+        @torch.no_grad()
+        def eval_step(self, batch, criterion):
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            imgs_gen = self.generator(z)
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+
+            logits = self.discriminator(imgs_gen)
+            loss_gen = criterion(logits, valid)
+            return {"model_outputs": logits}, {"loss_gen": loss_gen}
+
+        def get_optimizer(self):
+            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
+            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
+            return [discriminator_optimizer, generator_optimizer]
+
+        def get_criterion(self):
+            return nn.BCELoss()
+
+        def get_data_loader(
+            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
+        ):  # pylint: disable=unused-argument
+            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
+            dataset.data = dataset.data[:64]
+            dataset.targets = dataset.targets[:64]
+            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
+            return dataloader
+
+    config = GANModelConfig()
+    config.batch_size = 64
+    config.grad_clip = None
+
+    model = GANModel()
+    trainer = Trainer(
+        TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
+    )
+
+    trainer.config.epochs = 1
+    trainer.fit()
+    loss_d1 = trainer.keep_avg_train["avg_loss_disc"]
+    loss_g1 = trainer.keep_avg_train["avg_loss_gen"]
+
+    trainer.config.epochs = 5
+    trainer.fit()
+    loss_d2 = trainer.keep_avg_train["avg_loss_disc"]
+    loss_g2 = trainer.keep_avg_train["avg_loss_gen"]
+
+    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
+    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
+    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
+    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
+
+
 if __name__ == "__main__":
     test_overfit_mnist_simple_gan()
+    test_overfit_accelerate_mnist_simple_gan()
     test_overfit_manual_optimize_mnist_simple_gan()
     test_overfit_manual_optimize_grad_accum_mnist_simple_gan()
+    test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan()