
Commit 0a3e1ec

Accelerate (#114)
* Refactor optimize
* Add accelerate tests
* Remove line
* Make style
* Print out accelerate
* Fixup
* Fixup
* Fixup
* Add mixed precision with accelerate
* Fixup mixed precision
* Make style
* Print logs only for rank zero
* Allow tf32
* Fixup
* Fixup
* Fix setting up logger
* Update README
* Fix linter
* Fixup
* Update test
* Handle different types in prepare
* Compute grad norm
* Fix eval_epoch single run
* Make style
* Meh linter
1 parent 50f9d86 commit 0a3e1ec

9 files changed: +549 -84 lines

.pylintrc

Lines changed: 1 addition & 2 deletions
@@ -404,10 +404,9 @@ logging-modules=logging
 [MESSAGES CONTROL]
 
 # Only show warnings with the listed confidence levels. Leave empty to show
-# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE,
 # UNDEFINED.
 confidence=HIGH,
-           CONTROL_FLOW,
            INFERENCE,
            INFERENCE_FAILURE,
            UNDEFINED

README.md

Lines changed: 67 additions & 1 deletion
@@ -25,11 +25,65 @@ Prefer installing from Github as it is more stable.
 Subclass and overload the functions in the [```TrainerModel()```](trainer/model.py)
 
 
-## Training a model with auto optimization
+## Training a model with auto-optimization
 See the [MNIST example](examples/train_mnist.py).
 
 
 ## Training a model with advanced optimization
+With 👟 you can define the whole optimization cycle as you want, as in the GAN example below. It gives you more
+under-the-hood control and flexibility for advanced training loops.
+
+You just have to use the ```scaled_backward()``` function to handle mixed-precision training.
+
+```python
+...
+
+    def optimize(self, batch, trainer):
+        imgs, _ = batch
+
+        # sample noise
+        z = torch.randn(imgs.shape[0], 100)
+        z = z.type_as(imgs)
+
+        # train discriminator
+        imgs_gen = self.generator(z)
+        logits = self.discriminator(imgs_gen.detach())
+        fake = torch.zeros(imgs.size(0), 1)
+        fake = fake.type_as(imgs)
+        loss_fake = trainer.criterion(logits, fake)
+
+        valid = torch.ones(imgs.size(0), 1)
+        valid = valid.type_as(imgs)
+        logits = self.discriminator(imgs)
+        loss_real = trainer.criterion(logits, valid)
+        loss_disc = (loss_real + loss_fake) / 2
+
+        # step discriminator
+        _, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])
+
+        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+            trainer.optimizer[0].step()
+            trainer.optimizer[0].zero_grad()
+
+        # train generator
+        imgs_gen = self.generator(z)
+
+        valid = torch.ones(imgs.size(0), 1)
+        valid = valid.type_as(imgs)
+
+        logits = self.discriminator(imgs_gen)
+        loss_gen = trainer.criterion(logits, valid)
+
+        # step generator
+        _, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
+        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+            trainer.optimizer[1].step()
+            trainer.optimizer[1].zero_grad()
+        return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}
+
+...
+```
 See the [GAN training example](examples/train_simple_gan.py) with Gradient Accumulation
 
 
@@ -51,6 +105,18 @@ We don't use ```.spawn()``` to initiate multi-gpu training since it causes certa
 - ```.spawn()``` trains the model in subprocesses and the model in the main process is not updated.
 - DataLoader with N processes gets really slow when N is large.
 
+## Training with [Accelerate](https://huggingface.co/docs/accelerate/index)
+
+Setting `use_accelerate` in `TrainerArgs` to `True` enables training with Accelerate.
+
+You can also use it for multi-GPU or distributed training.
+
+```console
+CUDA_VISIBLE_DEVICES="0,1,2" accelerate launch --multi_gpu --num_processes 3 train_recipe_autoregressive_prompt.py
+```
+
+See the [Accelerate docs](https://huggingface.co/docs/accelerate/basic_tutorials/launch).
+
 ## Adding a callback
 👟 Supports callbacks to customize your runs. You can either set callbacks in your model implementations or give them
 explicitly to the Trainer.
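
As an aside, here is a minimal, hypothetical sketch of the code-side setup the new Accelerate section describes. `Trainer`, `TrainerArgs`, and the `use_accelerate` flag come from this commit; the `fit_with_accelerate` helper and the top-level `from trainer import ...` path are assumptions based on the test file added below.

```python
import os

from trainer import Trainer, TrainerArgs  # import path assumed; adjust to your install


def fit_with_accelerate(model, config, output_path=None):
    """Hypothetical helper: run training with the Accelerate backend enabled."""
    trainer = Trainer(
        TrainerArgs(use_accelerate=True),  # the flag introduced in this commit
        config,
        model=model,
        output_path=output_path or os.getcwd(),
    )
    trainer.fit()
    return trainer
```

Launched under `accelerate launch` as in the README example above, the same script should run multi-GPU without further changes.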

requirements.dev.txt

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@ coverage
 isort
 pytest
 pylint
+accelerate # for testing

tests/test_train_gan.py

Lines changed: 222 additions & 0 deletions
@@ -1,5 +1,6 @@
 import os
 from dataclasses import dataclass
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import torch
@@ -159,6 +160,103 @@ def get_data_loader(
     assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
 
 
+def test_overfit_accelerate_mnist_simple_gan():
+    @dataclass
+    class GANModelConfig(TrainerConfig):
+        epochs: int = 1
+        print_step: int = 2
+        training_seed: int = 666
+
+    class GANModel(TrainerModel):
+        def __init__(self):
+            super().__init__()
+            data_shape = (1, 28, 28)
+            self.generator = Generator(latent_dim=100, img_shape=data_shape)
+            self.discriminator = Discriminator(img_shape=data_shape)
+
+        def forward(self, x):
+            ...
+
+        def train_step(self, batch, criterion, optimizer_idx):
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            # train discriminator
+            if optimizer_idx == 0:
+                imgs_gen = self.generator(z)
+                logits = self.discriminator(imgs_gen.detach())
+                fake = torch.zeros(imgs.size(0), 1)
+                fake = fake.type_as(imgs)
+                loss_fake = criterion(logits, fake)
+
+                valid = torch.ones(imgs.size(0), 1)
+                valid = valid.type_as(imgs)
+                logits = self.discriminator(imgs)
+                loss_real = criterion(logits, valid)
+                loss = (loss_real + loss_fake) / 2
+                return {"model_outputs": logits}, {"loss": loss}
+
+            # train generator
+            if optimizer_idx == 1:
+                imgs_gen = self.generator(z)
+
+                valid = torch.ones(imgs.size(0), 1)
+                valid = valid.type_as(imgs)
+
+                logits = self.discriminator(imgs_gen)
+                loss_real = criterion(logits, valid)
+                return {"model_outputs": logits}, {"loss": loss_real}
+
+        @torch.no_grad()
+        def eval_step(self, batch, criterion, optimizer_idx):
+            return self.train_step(batch, criterion, optimizer_idx)
+
+        def get_optimizer(self):
+            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
+            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
+            return [discriminator_optimizer, generator_optimizer]
+
+        def get_criterion(self):
+            return nn.BCELoss()
+
+        def get_data_loader(
+            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
+        ):  # pylint: disable=unused-argument
+            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
+            dataset.data = dataset.data[:64]
+            dataset.targets = dataset.targets[:64]
+            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=False)
+            return dataloader
+
+    config = GANModelConfig()
+    config.batch_size = 64
+    config.grad_clip = None
+    config.training_seed = 333
+
+    model = GANModel()
+    trainer = Trainer(
+        TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
+    )
+
+    trainer.eval_epoch()
+    loss_d1 = trainer.keep_avg_eval["avg_loss_0"]
+    loss_g1 = trainer.keep_avg_eval["avg_loss_1"]
+
+    trainer.config.epochs = 5
+    trainer.fit()
+    loss_d2 = trainer.keep_avg_train["avg_loss_0"]
+    loss_g2 = trainer.keep_avg_train["avg_loss_1"]
+
+    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
+    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
+    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
+    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
+
+
 def test_overfit_manual_optimize_mnist_simple_gan():
     @dataclass
     class GANModelConfig(TrainerConfig):
@@ -390,7 +488,131 @@ def get_data_loader(
     assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
 
 
+def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan():
+    @dataclass
+    class GANModelConfig(TrainerConfig):
+        epochs: int = 1
+        print_step: int = 2
+        training_seed: int = 666
+
+    class GANModel(TrainerModel):
+        def __init__(self):
+            super().__init__()
+            data_shape = (1, 28, 28)
+            self.generator = Generator(latent_dim=100, img_shape=data_shape)
+            self.discriminator = Discriminator(img_shape=data_shape)
+
+        def train_step(self):
+            ...
+
+        def forward(self, x):
+            ...
+
+        def optimize(self, batch, trainer):
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            # train discriminator
+            imgs_gen = self.generator(z)
+            logits = self.discriminator(imgs_gen.detach())
+            fake = torch.zeros(imgs.size(0), 1)
+            fake = fake.type_as(imgs)
+            loss_fake = trainer.criterion(logits, fake)
+
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+            logits = self.discriminator(imgs)
+            loss_real = trainer.criterion(logits, valid)
+            loss_disc = (loss_real + loss_fake) / 2
+
+            # step discriminator
+            self.scaled_backward(loss_disc, trainer, trainer.optimizer[0])
+
+            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+                trainer.optimizer[0].step()
+                trainer.optimizer[0].zero_grad()
+
+            # train generator
+            imgs_gen = self.generator(z)
+
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+
+            logits = self.discriminator(imgs_gen)
+            loss_gen = trainer.criterion(logits, valid)
+
+            # step generator
+            self.scaled_backward(loss_gen, trainer, trainer.optimizer[1])
+            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
+                trainer.optimizer[1].step()
+                trainer.optimizer[1].zero_grad()
+            return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}
+
+        @torch.no_grad()
+        def eval_step(self, batch, criterion):
+            imgs, _ = batch
+
+            # sample noise
+            z = torch.randn(imgs.shape[0], 100)
+            z = z.type_as(imgs)
+
+            imgs_gen = self.generator(z)
+            valid = torch.ones(imgs.size(0), 1)
+            valid = valid.type_as(imgs)
+
+            logits = self.discriminator(imgs_gen)
+            loss_gen = criterion(logits, valid)
+            return {"model_outputs": logits}, {"loss_gen": loss_gen}
+
+        def get_optimizer(self):
+            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
+            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
+            return [discriminator_optimizer, generator_optimizer]
+
+        def get_criterion(self):
+            return nn.BCELoss()
+
+        def get_data_loader(
+            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
+        ):  # pylint: disable=unused-argument
+            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
+            dataset.data = dataset.data[:64]
+            dataset.targets = dataset.targets[:64]
+            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
+            return dataloader
+
+    config = GANModelConfig()
+    config.batch_size = 64
+    config.grad_clip = None
+
+    model = GANModel()
+    trainer = Trainer(
+        TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
+    )
+
+    trainer.config.epochs = 1
+    trainer.fit()
+    loss_d1 = trainer.keep_avg_train["avg_loss_disc"]
+    loss_g1 = trainer.keep_avg_train["avg_loss_gen"]
+
+    trainer.config.epochs = 5
+    trainer.fit()
+    loss_d2 = trainer.keep_avg_train["avg_loss_disc"]
+    loss_g2 = trainer.keep_avg_train["avg_loss_gen"]
+
+    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
+    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
+    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
+    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
+
+
 if __name__ == "__main__":
     test_overfit_mnist_simple_gan()
+    test_overfit_accelerate_mnist_simple_gan()
     test_overfit_manual_optimize_mnist_simple_gan()
     test_overfit_manual_optimize_grad_accum_mnist_simple_gan()
+    test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan()
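
Both manual-optimization tests repeat the same backward-then-maybe-step pattern once per optimizer. The sketch below only distills that pattern for readability; `backward_and_maybe_step` is a hypothetical helper, and the three-argument `scaled_backward(loss, trainer, optimizer)` call follows the test code above (the README example shows a four-argument form, so check the signature in your Trainer version).

```python
def backward_and_maybe_step(model, loss, trainer, optimizer):
    """Hypothetical helper: backward via scaled_backward, then step/zero the
    optimizer only on gradient-accumulation boundaries."""
    model.scaled_backward(loss, trainer, optimizer)  # handles mixed precision / Accelerate
    if trainer.total_steps_done % trainer.grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```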

trainer/generic_utils.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def isimplemented(obj, method_name):
     """Check if a method is implemented in a class."""
     if method_name in dir(obj) and callable(getattr(obj, method_name)):
         try:
-            obj.__getattribute__(method_name)()  # pylint: disable=unnecessary-dunder-call
+            obj.__getattribute__(method_name)()  # pylint: disable=bad-option-value, unnecessary-dunder-call
         except NotImplementedError:
             return False
         except:  # pylint: disable=bare-except
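
For context, an illustrative sketch of how `isimplemented()` is meant to behave. The `Base`/`Child` classes are made up, and the expected `True` result assumes the function returns `True` when the call does not raise `NotImplementedError` (that success path is outside the hunk shown above).

```python
from trainer.generic_utils import isimplemented  # module path taken from the file name above


class Base:
    def optimize(self):
        raise NotImplementedError  # stub: counts as "not implemented"


class Child(Base):
    def optimize(self):
        return "done"  # real override


print(isimplemented(Base(), "optimize"))   # False: the call raises NotImplementedError
print(isimplemented(Child(), "optimize"))  # True (assumed): the call succeeds
```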
