Commit b68ac29

Custom optimize for handling complex trainings (#89)
* Implement `isimplemented`
* Replace `hasattr` with `isimplemented`
* Add `torch.set_grad_enabled` toggles
* Implement custom optimize
* Update model
* Add GAN training tests
* Make style
* Make lint
* Do not expect `grad_norm` returned
* Add training examples
* Bump up to v0.0.21
1 parent 91d83a1 commit b68ac29
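
For context, `isimplemented` is the helper this commit introduces to replace bare `hasattr` checks, so the trainer can tell whether a model actually overrides an optional hook such as `optimize` rather than merely inheriting a stub. Below is a minimal sketch of the idea, assuming the base class declares hooks as `NotImplementedError` stubs; the names are illustrative and the exact helper in trainer/model.py may differ:

    from torch import nn


    class TrainerModel(nn.Module):
        """Base class with an optional hook declared as a stub."""

        def optimize(self, batch, trainer):
            raise NotImplementedError


    def isimplemented(obj, method_name: str) -> bool:
        # A method counts as implemented only if the subclass overrides
        # the base-class stub; `hasattr` alone would also return True
        # for the inherited stub.
        sub_method = getattr(type(obj), method_name, None)
        base_method = getattr(TrainerModel, method_name, None)
        return sub_method is not None and sub_method is not base_method


    class ManualModel(TrainerModel):
        def optimize(self, batch, trainer):
            return None


    print(isimplemented(TrainerModel(), "optimize"))  # False
    print(isimplemented(ManualModel(), "optimize"))   # True

The `torch.set_grad_enabled` toggles mentioned in the message are the stock PyTorch switch for turning gradient tracking on and off, presumably used so the trainer can control gradient state explicitly around training and evaluation passes.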

File tree

8 files changed: +974 −113 lines

README.md

Lines changed: 10 additions & 4 deletions

@@ -24,8 +24,14 @@ Prefer installing from Github as it is more stable.
 ## Implementing a model
 Subclass and overload the functions in the [```TrainerModel()```](trainer/model.py)
 
-## Training a model
-See the test script [here](tests/test_train_mnist.py) training a basic MNIST model.
+
+## Training a model with auto optimization
+See the [MNIST example](examples/train_mnist.py).
+
+
+## Training a model with advanced optimization
+See the [GAN training example](examples/train_simple_gan.py) with Gradient Accumulation
+
 
 ## Training with Batch Size Finder
 see the test script [here](tests/test_train_batch_size_finder.py) for training with batch size finder.

@@ -95,6 +101,6 @@ trainer.fit()
 To add a new logger, you must subclass [BaseDashboardLogger](trainer/logging/base_dash_logger.py) and overload its functions.
 
 ## Anonymized Telemetry
-We constantly seek to improve 🐸 for the community. To understand the community's needs better and address them accordingly, we collect stripped-down anonymized usage stats when you run the trainer.
+We constantly seek to improve 🐸 for the community. To understand the community's needs better and address them accordingly, we collect stripped-down anonymized usage stats when you run the trainer.
 
-Of course, if you don't want, you can opt out by setting the environment variable `TRAINER_TELEMETRY=0`.
+Of course, if you don't want, you can opt out by setting the environment variable `TRAINER_TELEMETRY=0`.
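
The two new README sections distinguish the training modes this commit adds: in auto optimization the model only returns its losses from `train_step` and the trainer performs backward and optimizer steps; in advanced optimization the model overrides `optimize` and steps its optimizers itself, as in the GAN example below. Here is a minimal, runnable sketch of roughly what the trainer does for you in auto mode, using plain PyTorch stand-ins; `TinyModel` is illustrative and not part of the trainer API:

    import torch
    from torch import nn


    class TinyModel(nn.Module):
        """Stand-in for a TrainerModel subclass; the hook signature
        mirrors the examples in this commit."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

        # Auto optimization: return outputs and losses, nothing else.
        def train_step(self, batch, criterion):
            x, y = batch
            logits = self(x)
            return {"model_outputs": logits}, {"loss": criterion(logits, y)}


    if __name__ == "__main__":
        model = TinyModel()
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        batch = (torch.randn(8, 4), torch.randint(0, 2, (8,)))
        _, losses = model.train_step(batch, nn.CrossEntropyLoss())
        # This is roughly what the trainer does for you in auto mode:
        losses["loss"].backward()
        opt.step()
        opt.zero_grad()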

examples/train_mnist.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
"""
This example shows training of a simple fully-connected model with the MNIST dataset using the Auto Training mode of 👟.
"""

import os
from dataclasses import dataclass

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs


@dataclass
class MnistModelConfig(TrainerConfig):
    optimizer: str = "Adam"
    lr: float = 0.001
    epochs: int = 1
    print_step: int = 1
    save_step: int = 5
    plot_step: int = 5
    dashboard_logger: str = "tensorboard"


class MnistModel(TrainerModel):
    def __init__(self):
        super().__init__()

        # MNIST images are (1, 28, 28) (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        batch_size, _, _, _ = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def train_step(self, batch, criterion):
        x, y = batch
        logits = self(x)
        loss = criterion(logits, y)
        return {"model_outputs": logits}, {"loss": loss}

    def eval_step(self, batch, criterion):
        x, y = batch
        logits = self(x)
        loss = criterion(logits, y)
        return {"model_outputs": logits}, {"loss": loss}

    @staticmethod
    def get_criterion():
        return torch.nn.NLLLoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
        # Keep the run small: use only the first 256 samples.
        dataset.data = dataset.data[:256]
        dataset.targets = dataset.targets[:256]
        dataloader = DataLoader(dataset, batch_size=config.batch_size)
        return dataloader


def main():
    """Run `MNIST` model training from scratch or from a previous checkpoint."""
    # init args and config
    train_args = TrainerArgs()
    config = MnistModelConfig()

    # init the model from config
    model = MnistModel()

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=model.get_data_loader(config, None, False, None, None, None),
        eval_samples=model.get_data_loader(config, None, True, None, None, None),
        parse_command_line_args=True,
    )
    trainer.fit()


if __name__ == "__main__":
    main()

examples/train_simple_gan.py

Lines changed: 176 additions & 0 deletions

@@ -0,0 +1,176 @@
"""
This example shows training of a simple GAN model with the MNIST dataset using Gradient Accumulation and Advanced
Optimization, where you call the optimizer steps manually.
"""

import os
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from trainer import Trainer, TrainerConfig, TrainerModel
from trainer.trainer import TrainerArgs

is_cuda = torch.cuda.is_available()


# pylint: skip-file


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NB: the positional 0.8 sets `eps`, not `momentum`.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


@dataclass
class GANModelConfig(TrainerConfig):
    epochs: int = 1
    print_step: int = 2
    training_seed: int = 666


class GANModel(TrainerModel):
    def __init__(self):
        super().__init__()
        data_shape = (1, 28, 28)
        self.generator = Generator(latent_dim=100, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

    def forward(self, x):
        ...

    def optimize(self, batch, trainer):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], 100)
        z = z.type_as(imgs)

        # train the discriminator on generated (fake) and real images
        imgs_gen = self.generator(z)
        logits = self.discriminator(imgs_gen.detach())
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        loss_fake = trainer.criterion(logits, fake)

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        logits = self.discriminator(imgs)
        loss_real = trainer.criterion(logits, valid)
        loss_disc = (loss_real + loss_fake) / 2

        # step the discriminator only every `grad_accum_steps` steps
        _, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
            trainer.optimizer[0].step()
            trainer.optimizer[0].zero_grad()

        # train the generator to fool the discriminator
        imgs_gen = self.generator(z)

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        logits = self.discriminator(imgs_gen)
        loss_gen = trainer.criterion(logits, valid)

        # step the generator only every `grad_accum_steps` steps
        _, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
        if trainer.total_steps_done % trainer.grad_accum_steps == 0:
            trainer.optimizer[1].step()
            trainer.optimizer[1].zero_grad()
        return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

    @torch.no_grad()
    def eval_step(self, batch, criterion):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], 100)
        z = z.type_as(imgs)

        imgs_gen = self.generator(z)
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        logits = self.discriminator(imgs_gen)
        # use the criterion passed in; `trainer` is not in scope here
        loss_gen = criterion(logits, valid)
        return {"model_outputs": logits}, {"loss_gen": loss_gen}

    def get_optimizer(self):
        discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
        return [discriminator_optimizer, generator_optimizer]

    def get_criterion(self):
        return nn.BCELoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
        # Keep the run small: use only the first 64 samples.
        dataset.data = dataset.data[:64]
        dataset.targets = dataset.targets[:64]
        dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
        return dataloader


if __name__ == "__main__":
    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)
    trainer.config.epochs = 10
    trainer.fit()
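
Gating `optimizer.step()` on `trainer.total_steps_done % trainer.grad_accum_steps` in `optimize` above is the standard manual gradient-accumulation pattern. A minimal plain-PyTorch sketch of the same idea, independent of the trainer API:

    import torch
    from torch import nn

    # Gradients from `accum_steps` micro-batches are summed by backward()
    # before a single optimizer step is taken.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    accum_steps = 4

    for step in range(1, 17):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = criterion(model(x), y)
        # Scale so the accumulated gradient matches one large-batch step.
        (loss / accum_steps).backward()
        if step % accum_steps == 0:  # step only every accum_steps batches
            optimizer.step()
            optimizer.zero_grad()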
