Commit 4a7958d ("more cleanup")
Parent: 1c25f54
1 file changed: train_dalle.py (+37, -30)

train_dalle.py

@@ -18,6 +18,7 @@
 from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer

 # libraries needed for webdataset support
+
 import webdataset as wds
 from torchvision import transforms as T
 from PIL import Image

@@ -224,6 +225,8 @@ def cp_path_to_dir(cp_path, tag):
 using_deepspeed = \
     distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

+is_root = distr_backend.is_root_worker()
+
 # tokenizer

 if exists(args.bpe_path):
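
The heart of the cleanup: the result of distr_backend.is_root_worker() is now computed once, right after the backend is set up, and the cached is_root flag replaces every later call. A minimal sketch of the pattern, using a hypothetical stand-in for the script's distr_backend:

# Sketch only: StubBackend is a stand-in for the backend object the real
# script gets from dalle_pytorch.distributed_utils.
class StubBackend:
    def is_root_worker(self):
        # The real backend inspects the process's distributed rank.
        return True

distr_backend = StubBackend()

# Evaluate the rank check once; reuse the cached flag everywhere below.
is_root = distr_backend.is_root_worker()

if is_root:
    print('only the root worker prints, saves checkpoints, and logs to wandb')
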
@@ -275,7 +278,7 @@ def cp_path_to_dir(cp_path, tag):
         vae = DiscreteVAE(**vae_params)
         vae.load_state_dict(weights)
     else:
-        if distr_backend.is_root_worker():
+        if is_root:
             print('using pretrained VAE for encoding images to tokens')
         vae_params = None

@@ -381,7 +384,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 )
 assert len(ds) > 0, 'dataset is empty'

-if distr_backend.is_root_worker():
+if is_root:
     if not ENABLE_WEBDATASET:
         print(f'{len(ds)} image-text pairs found for training')

@@ -445,7 +448,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available

 # experiment tracker

-if distr_backend.is_root_worker():
+if is_root:

     model_config = dict(
         depth=DEPTH,
@@ -529,14 +532,14 @@ def save_model(path, epoch=0):
     if using_deepspeed:
         cp_dir = cp_path_to_dir(path, 'ds')

-        if KEEP_N_CHECKPOINTS is not None and distr_backend.is_root_worker():
+        if KEEP_N_CHECKPOINTS is not None and is_root:
             checkpoints = sorted(glob(str(cp_dir / "global*")), key=os.path.getmtime, reverse=True)
             for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]:
                 shutil.rmtree(checkpoint)

         distr_dalle.save_checkpoint(cp_dir, client_state=save_obj)

-        if not distr_backend.is_root_worker():
+        if not is_root:
             return

         # Save auxiliary values so we can reuse the standard routine
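
For context, the rotation logic in this hunk (only the rank check changed) keeps the newest KEEP_N_CHECKPOINTS DeepSpeed checkpoint directories and removes the rest. A self-contained sketch of the same idea; the directory layout is an assumption mirroring the script's cp_dir / "global*" pattern:

import os
import shutil
from glob import glob
from pathlib import Path

KEEP_N_CHECKPOINTS = 2        # keep only the two newest checkpoints
cp_dir = Path('checkpoints')  # assumed checkpoint directory

# Sort checkpoint dirs newest-first by modification time, then prune the tail.
checkpoints = sorted(glob(str(cp_dir / 'global*')), key=os.path.getmtime, reverse=True)
for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]:
    shutil.rmtree(checkpoint)
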
@@ -554,7 +557,7 @@ def save_model(path, epoch=0):
     if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2: # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
         return

-    if not distr_backend.is_root_worker():
+    if not is_root:
         return

     save_obj = {
@@ -566,19 +569,29 @@ def save_model(path, epoch=0):

     torch.save(save_obj, path)

+def save_artifact(model_config, model_path, name = 'trained-dalle'):
+    model_artifact = wandb.Artifact(name, type='model', metadata=dict(model_config))
+    model_artifact.add_file(model_path)
+    run.log_artifact(model_artifact)
+
 # training

 # Saves a checkpoint before training begins to fail early when mis-configured.
 # See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
+
 save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
+
 for epoch in range(resume_epoch, EPOCHS):
     if data_sampler:
         data_sampler.set_epoch(epoch)
+
     for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
-        if i % 10 == 0 and distr_backend.is_root_worker():
+        if i % 10 == 0 and is_root:
             t = time.time()
+
         if args.fp16:
             images = images.half()
+
         text, images = map(lambda t: t.cuda(), (text, images))

         loss = distr_dalle(text, images, return_loss=True)
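
The other substantive change is the new save_artifact helper, which deduplicates the wandb artifact upload that previously appeared verbatim in two places (at each epoch's end and again after training). Roughly, it is used like this; the run object comes from the wandb.init(...) call earlier in the script, and the project name, hyperparameters, and file path below are placeholders:

import wandb

run = wandb.init(project='dalle-train')  # placeholder project name
model_config = dict(depth=16, heads=8)   # placeholder hyperparameters

def save_artifact(model_config, model_path, name = 'trained-dalle'):
    # Wrap the checkpoint file in a versioned wandb Artifact, carrying the
    # model hyperparameters along as artifact metadata.
    model_artifact = wandb.Artifact(name, type='model', metadata=dict(model_config))
    model_artifact.add_file(model_path)
    run.log_artifact(model_artifact)

# Log an existing checkpoint file as a new version of 'trained-dalle'.
save_artifact(model_config, 'dalle.pt')
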
@@ -598,7 +611,7 @@ def save_model(path, epoch=0):

         log = {}

-        if i % 10 == 0 and distr_backend.is_root_worker():
+        if i % 10 == 0 and is_root:
             print(epoch, i, f'loss - {avg_loss.item()}')

             log = {
@@ -611,47 +624,41 @@ def save_model(path, epoch=0):
         if i % SAVE_EVERY_N_STEPS == 0:
             save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

-        if i % 100 == 0:
-            if distr_backend.is_root_worker():
-                sample_text = text[:1]
-                token_list = sample_text.masked_select(sample_text != 0).tolist()
-                decoded_text = tokenizer.decode(token_list)
+        if i % 100 == 0 and is_root:
+            sample_text = text[:1]
+            token_list = sample_text.masked_select(sample_text != 0).tolist()
+            decoded_text = tokenizer.decode(token_list)

-                if not avoid_model_calls:
-                    # CUDA index errors when we don't guard this
-                    image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9
+            if not avoid_model_calls:
+                # CUDA index errors when we don't guard this
+                image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9

-                if not avoid_model_calls:
-                    log['image'] = wandb.Image(image, caption=decoded_text)
+            if not avoid_model_calls:
+                log['image'] = wandb.Image(image, caption=decoded_text)

-        if i % 10 == 9 and distr_backend.is_root_worker():
+        if i % 10 == 9 and is_root:
             sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
             log["sample_per_sec"] = sample_per_sec
             print(epoch, i, f'sample_per_sec - {sample_per_sec}')

         if i == 201 and args.flops_profiler:
             raise StopIteration("Profiler has finished running. Stopping training early.")

-        if distr_backend.is_root_worker():
+        if is_root:
             wandb.log(log)

     if LR_DECAY:
         distr_scheduler.step(avg_loss)

     save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)

-    if distr_backend.is_root_worker():
+    if is_root:
         # save trained model to wandb as an artifact every epoch's end
-
-        model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
-        model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
-        run.log_artifact(model_artifact)
+        save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)

 save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-if distr_backend.is_root_worker():
-    wandb.save(DALLE_OUTPUT_FILE_NAME)
-    model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
-    model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
-    run.log_artifact(model_artifact)

+if is_root:
+    wandb.save(DALLE_OUTPUT_FILE_NAME)
+    save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)
 wandb.finish()
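
Because save_artifact runs at every epoch's end and once more after training, wandb accumulates a versioned history of 'trained-dalle' artifacts. One way a consumer might fetch the latest version through the public API; the entity and project below are placeholders:

import wandb

api = wandb.Api()
# '<entity>/<project>' is a placeholder for wherever the training run logged.
artifact = api.artifact('<entity>/<project>/trained-dalle:latest')
model_dir = artifact.download()  # downloads the artifact's files locally
print(model_dir)
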
