-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_vae.py
More file actions
366 lines (313 loc) · 15.6 KB
/
train_vae.py
File metadata and controls
366 lines (313 loc) · 15.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import json
import copy
from collections import OrderedDict
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from tqdm.auto import tqdm
from omegaconf import OmegaConf
import wandb
from util.crop import center_crop_arr
from loss.losses import ReconstructionLoss_Single_Stage
from models.vae import AutoencoderKL, VAE_MODEL_DICT
import util.misc as misc
from engine import update_ema
def img2save(img):
    """
    Rescale an image tensor for saving/logging.

    Applies ``x * 0.5 + 0.5`` (the inverse of a mean-0.5 / std-0.5
    normalisation) and clamps the result to the range [0, 1].
    """
    rescaled = img * 0.5 + 0.5
    return rescaled.clamp(0, 1)
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for every parameter of ``model``.

    ``flag=True`` makes all parameters trainable; ``flag=False`` freezes them.
    """
    for param in model.parameters():
        param.requires_grad = flag
def gather(tensor):
    """
    Average a tensor across all workers and return it as a Python scalar.

    Despite the name this is a mean all-reduce, not a gather: the value is
    summed over every rank and divided by the world size. Note that this also
    averages quantities such as per-rank min/max rather than reducing them
    with MIN/MAX.

    When no process group is initialised (single-process run) the local value
    is returned unchanged instead of raising.

    Args:
        tensor: a single-element tensor; it is neither modified nor detached
            in place.

    Returns:
        The cross-rank mean (or the local value) via ``Tensor.item()``.
    """
    # Detach before cloning so the in-place collective never touches the
    # caller's tensor or its autograd graph.
    reduced = tensor.detach().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        reduced /= dist.get_world_size()
    return reduced.item()
#################################################################################
# Training Loop #
#################################################################################
def main(args):
    """
    Train the VAE together with its reconstruction/GAN loss module.

    Steps performed:
      1. Initialise (optionally distributed) execution, seeds and, on rank 0,
         the wandb run.
      2. Build the ``AutoencoderKL`` selected by ``args.vae``, the
         ``ReconstructionLoss_Single_Stage`` loss, and one AdamW optimizer
         for each of them.
      3. Build an ``ImageFolder`` training loader over ``<data_path>/train``.
      4. Optionally resume from ``<args.resume>/checkpoint-last.pth``.
      5. Alternate a VAE ("generator") update and a loss-module
         ("discriminator") update per batch, keep an EMA copy of the VAE
         parameters, and log / checkpoint / upload sample images on rank 0.

    Args:
        args: namespace produced by ``parse_args``. ``args.distributed`` and
            ``args.gpu`` are read here but are not defined by the parser;
            presumably ``misc.init_distributed_mode`` sets them -- confirm.
    """
    misc.init_distributed_mode(args)
    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    if global_rank == 0:
        print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
        print("{}".format(args).replace(', ', ',\n'))
        os.makedirs(args.log_dir, exist_ok=True)
        wandb.init(
            project="SimFlow",
            name=f"{args.vae}",
            config=args,
            dir=args.log_dir,
        )

    # Build the VAE; an EMA copy of its parameters is created further below.
    vae_config = VAE_MODEL_DICT[args.vae]
    vae_embed_dim = vae_config["embed_dim"]
    ch_mult = vae_config["ch_mult"]
    vae = AutoencoderKL(embed_dim=vae_embed_dim, ch_mult=ch_mult,
                        use_variational=args.use_variational, fixed_std=args.fixed_std,
                        ln=args.vae_ln).to(device)
    if args.vae_path is not None:
        vae.init_from_ckpt(args.vae_path)

    if global_rank == 0:
        print("Model = %s" % str(vae))
        n_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
        print("Number of trainable parameters: {}M".format(n_params / 1e6))

    loss_cfg = OmegaConf.load(args.loss_cfg_path)
    vae_loss_fn = ReconstructionLoss_Single_Stage(loss_cfg).to(device)

    # Unwrapped handles are bound up front so that resume/checkpoint code
    # works whether or not the modules get wrapped in DDP below.
    vae_without_ddp = vae
    vae_loss_fn_without_ddp = vae_loss_fn

    # Separate optimizers for the VAE and for the loss module (discriminator).
    optimizer_vae = torch.optim.AdamW(
        vae.parameters(),
        lr=args.vae_learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    optimizer_loss_fn = torch.optim.AdamW(
        vae_loss_fn.parameters(),
        lr=args.disc_learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data
    # augmentation following DiT and ADM
    transform_train = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    if global_rank == 0:
        print(dataset_train)
        print("Sampler_train = %s" % str(sampler_train))
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    if args.distributed:
        vae = torch.nn.parallel.DistributedDataParallel(vae, device_ids=[args.gpu])
        vae_without_ddp = vae.module
        vae_loss_fn = torch.nn.parallel.DistributedDataParallel(vae_loss_fn, device_ids=[args.gpu])
        vae_loss_fn_without_ddp = vae_loss_fn.module

    # resume training
    global_step = 0
    resume_path = os.path.join(args.resume, "checkpoint-last.pth") if args.resume else None
    if resume_path is not None and os.path.exists(resume_path):
        # The checkpoint stores the argparse namespace under "args", which the
        # weights-only unpickler rejects, so full unpickling is requested.
        # SECURITY: this executes pickle code -- only resume from trusted files.
        checkpoint = torch.load(resume_path, map_location='cpu', weights_only=False)
        vae_without_ddp.load_state_dict(checkpoint['vae'])
        vae_params = list(vae_without_ddp.parameters())
        vae_ema_state_dict = checkpoint['vae_ema']
        # EMA tensors are kept as a flat list ordered like named_parameters().
        vae_ema_params = [vae_ema_state_dict[name].to(device) for name, _ in vae_without_ddp.named_parameters()]
        if global_rank == 0:
            print("Resume checkpoint %s" % args.resume)

        vae_loss_fn_without_ddp.discriminator.load_state_dict(checkpoint['vae_disc'])
        optimizer_vae.load_state_dict(checkpoint['optimizer_vae'])
        optimizer_loss_fn.load_state_dict(checkpoint['optimizer_loss_fn'])
        # NOTE(review): "steps" is the index of the step that was just
        # completed when saving, so that step is re-run after resuming.
        global_step = checkpoint['steps']
        if global_rank == 0:
            print("With optim & sched!")
        del checkpoint
    else:
        vae_params = list(vae_without_ddp.parameters())
        vae_ema_params = copy.deepcopy(vae_params)
        if global_rank == 0:
            print("Training from scratch")

    # NOTE: --disc-pretrained-ckpt is parsed but currently unused; the loading
    # code below is disabled.
    # if args.disc_pretrained_ckpt is not None:
    #     # Load the discriminator from a pretrained checkpoint if provided
    #     disc_ckpt = torch.load(args.disc_pretrained_ckpt, map_location=device)
    #     vae_loss_fn.discriminator.load_state_dict(disc_ckpt)
    #     if global_rank == 0:
    #         print(f"Loaded discriminator from {args.disc_pretrained_ckpt}")

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar on the rank-0 process.
        disable=not (global_rank == 0),
    )

    # main training loop
    for epoch in range(args.epochs):
        # Without this the DistributedSampler reuses the same shuffle order
        # in every epoch.
        sampler_train.set_epoch(epoch)

        for imgs, _labels in data_loader_train:
            # Move the batch explicitly so the step does not rely on DDP's
            # input scattering (which is absent in non-distributed runs).
            imgs = imgs.to(device, non_blocking=True)

            # NOTE(review): the source's indentation was lost; the whole
            # optimisation step is kept inside the autocast context here.
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                # 1). Forward pass: VAE
                posterior, z_ori, z, recon_image = vae(imgs)
                z_ori_mean_log, z_ori_std_log, z_ori_min_log, z_ori_max_log = z_ori.mean(), z_ori.std(), z_ori.min(), z_ori.max()
                z_mean_log, z_std_log, z_min_log, z_max_log = z.mean(), z.std(), z.min(), z.max()

                # 2). Backward pass: compute the VAE loss, backpropagate and
                # update the VAE; then compute the discriminator loss and
                # update the discriminator.
                vae_loss, vae_loss_dict = vae_loss_fn(imgs, recon_image, posterior, global_step, "generator")
                vae_loss.backward()
                grad_norm_vae = torch.nn.utils.clip_grad_norm_(vae.parameters(), args.max_grad_norm)
                optimizer_vae.step()
                optimizer_vae.zero_grad(set_to_none=True)
                torch.cuda.synchronize()

                # discriminator loss and update
                d_loss, d_loss_dict = vae_loss_fn(imgs, recon_image, posterior, global_step, "discriminator")
                d_loss.backward()
                grad_norm_disc = torch.nn.utils.clip_grad_norm_(vae_loss_fn.parameters(), args.max_grad_norm)
                optimizer_loss_fn.step()
                optimizer_loss_fn.zero_grad(set_to_none=True)
                torch.cuda.synchronize()

                update_ema(vae_ema_params, vae_params, rate=args.ema_rate)

            # Prepare the logs based on the current step. gather() is a
            # collective, so it must run on every rank, not only rank 0.
            logs = {
                "epoch": epoch,
                # VAE loss
                "vae_loss/vae_loss": gather(vae_loss),
                "vae_loss/reconstruction_loss": gather(vae_loss_dict["reconstruction_loss"]),
                "vae_loss/perceptual_loss": gather(vae_loss_dict["perceptual_loss"]),
                "vae_loss/kl_loss": gather(vae_loss_dict["kl_loss"]),
                "vae_loss/weighted_gan_loss": gather(vae_loss_dict["weighted_gan_loss"]),
                "vae_loss/discriminator_factor": gather(vae_loss_dict["discriminator_factor"]),
                "vae_loss/gan_loss": gather(vae_loss_dict["gan_loss"]),
                "vae_loss/d_weight": gather(vae_loss_dict["d_weight"]),
                # Statistics
                "stats/z_ori_mean": gather(z_ori_mean_log),
                "stats/z_ori_std": gather(z_ori_std_log),
                "stats/z_ori_min": gather(z_ori_min_log),
                "stats/z_ori_max": gather(z_ori_max_log),
                "stats/z_mean": gather(z_mean_log),
                "stats/z_std": gather(z_std_log),
                "stats/z_min": gather(z_min_log),
                "stats/z_max": gather(z_max_log),
                # Gradient norm
                "grad_norm/grad_norm_vae": gather(grad_norm_vae),
                "grad_norm/grad_norm_disc": gather(grad_norm_disc),
                # Discriminator loss
                "d_loss/d_loss": gather(d_loss),
                "d_loss/logits_real": gather(d_loss_dict["logits_real"]),
                "d_loss/logits_fake": gather(d_loss_dict["logits_fake"]),
                "d_loss/lecam_loss": gather(d_loss_dict["lecam_loss"]),
            }

            # wandb is only initialised on rank 0, and a single writer avoids
            # several ranks racing on the same checkpoint files.
            if global_rank == 0:
                wandb.log(logs, step=global_step)

                if global_step % args.checkpointing_steps == 0 or global_step + 1 == args.max_train_steps:
                    # Start from the full state dict (keeps buffers), then
                    # overwrite each parameter entry with its EMA tensor.
                    vae_ema_state_dict = copy.deepcopy(vae_without_ddp.state_dict())
                    for i, (name, _value) in enumerate(vae_without_ddp.named_parameters()):
                        assert name in vae_ema_state_dict
                        vae_ema_state_dict[name] = vae_ema_params[i]
                    checkpoint = {
                        "vae_ema": vae_ema_state_dict,
                        "vae": vae_without_ddp.state_dict(),
                        "vae_disc": vae_loss_fn_without_ddp.discriminator.state_dict(),
                        "optimizer_vae": optimizer_vae.state_dict(),
                        "optimizer_loss_fn": optimizer_loss_fn.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    torch.save(checkpoint, os.path.join(args.output_dir, "checkpoint-last.pth"))
                    torch.save(checkpoint, os.path.join(args.output_dir, f"checkpoint-{global_step}.pth"))
                    print(f"Saved checkpoint to {args.output_dir}")

                if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
                    wandb.log({"Original images": wandb.Image(img2save(imgs[:8]))}, step=global_step)
                    wandb.log({"Reconstructed": wandb.Image(img2save(recon_image[:8]))}, step=global_step)

            progress_bar.update(1)
            global_step += 1

            if global_step >= args.max_train_steps:
                break
        # Propagate the step-budget stop out of the epoch loop as well.
        if global_step >= args.max_train_steps:
            break

    if global_rank == 0:
        print("Done!")
        wandb.finish()
def parse_args(input_args=None):
    """
    Build the command-line parser for VAE training and parse arguments.

    Args:
        input_args: optional list of argument strings to parse. When ``None``
            (the default) the process command line (``sys.argv[1:]``) is used.

    Returns:
        The parsed ``argparse.Namespace``. Dashes in option names become
        underscores in attribute names (e.g. ``--max-train-steps`` ->
        ``args.max_train_steps``).
    """
    parser = argparse.ArgumentParser(description="Training")

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')

    # logging params
    parser.add_argument("--output_dir", type=str, default="exps")
    parser.add_argument("--log_dir", type=str, default="logs")
    parser.add_argument("--sampling-steps", type=int, default=5000)
    parser.add_argument("--continue-train-exp-dir", type=str, default=None)

    # dataset params
    parser.add_argument("--data_path", type=str, default="data")
    parser.add_argument('--img_size', default=256, type=int, help='images input size')
    parser.add_argument("--batch_size", type=int, default=256)

    # optimization params
    parser.add_argument("--epochs", type=int, default=1400)
    parser.add_argument("--max-train-steps", type=int, default=400000)
    parser.add_argument("--checkpointing-steps", type=int, default=50000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')

    # seed params
    parser.add_argument("--seed", type=int, default=0)

    # cpu params
    parser.add_argument("--num-workers", type=int, default=4)

    # vae params
    parser.add_argument("--vae", type=str, default="vae_f8d4")
    parser.add_argument("--vae_path", type=str, default=None)
    parser.add_argument('--use_variational', default="True", type=misc.str_to_bool, help='Whether to use the variational version of the VAE')
    parser.add_argument('--fixed_std', default=1.0, type=float, help='Fixed standard deviation for the VAE')
    parser.add_argument("--vae_ln", default="True", type=misc.str_to_bool, help='Whether to use layer norm in the VAE')

    # vae loss params
    parser.add_argument("--disc-pretrained-ckpt", type=str, default=None)
    parser.add_argument("--loss-cfg-path", type=str, default="configs/l1_lpips_kl_gan.yaml")

    # vae training params
    parser.add_argument("--vae-learning-rate", type=float, default=1e-4)
    parser.add_argument("--disc-learning-rate", type=float, default=1e-4)
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # Honour an explicit argument list; None falls back to sys.argv[1:].
    return parser.parse_args(input_args)
if __name__ == "__main__":
    # Script entry point: parse the CLI, prepare the output directory, train.
    args = parse_args()
    # Logs are written alongside the checkpoints; this replaces whatever
    # value --log_dir was given.
    args.log_dir = args.output_dir
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)