@@ -18,6 +18,7 @@
 from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, ChineseTokenizer, YttmTokenizer
 
 # libraries needed for webdataset support
+
 import webdataset as wds
 from torchvision import transforms as T
 from PIL import Image
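webdataset streams (image, caption) pairs out of tar shards instead of loose files, which is why these extra imports land here. A rough sketch of the kind of loader the import enables (the shard pattern and field names are illustrative, not the script's):

import webdataset as wds
from torchvision import transforms as T

# hypothetical shard pattern; real runs point this at the dataset's .tar files
urls = 'data/shard-{000000..000099}.tar'

image_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(256),
    T.ToTensor(),
])

# decode to PIL, select the image and caption fields, then apply transforms
ds = (
    wds.WebDataset(urls)
    .decode('pil')
    .to_tuple('jpg;png', 'txt')
    .map_tuple(image_transform, lambda caption: caption)
)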
@@ -224,6 +225,8 @@ def cp_path_to_dir(cp_path, tag):
 using_deepspeed = \
     distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
 
+is_root = distr_backend.is_root_worker()
+
 # tokenizer
 
 if exists(args.bpe_path):
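The new is_root flag evaluates distr_backend.is_root_worker() once and is reused everywhere below, replacing a dozen repeated backend calls. It is the standard rank-0 gate for data-parallel training, where printing, logging, and saving should happen on exactly one process. A minimal sketch of the same idea with plain torch.distributed (assuming the process group is already initialized):

import torch.distributed as dist

# cache the rank check once; side effects guarded by it run on a single process
is_root = (not dist.is_initialized()) or dist.get_rank() == 0

if is_root:
    print('only the root worker logs this')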
@@ -275,7 +278,7 @@ def cp_path_to_dir(cp_path, tag):
         vae = DiscreteVAE(**vae_params)
         vae.load_state_dict(weights)
     else:
-        if distr_backend.is_root_worker():
+        if is_root:
             print('using pretrained VAE for encoding images to tokens')
         vae_params = None
@@ -381,7 +384,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
     )
     assert len(ds) > 0, 'dataset is empty'
 
-if distr_backend.is_root_worker():
+if is_root:
     if not ENABLE_WEBDATASET:
         print(f'{len(ds)} image-text pairs found for training')
@@ -445,7 +448,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 
 # experiment tracker
 
-if distr_backend.is_root_worker():
+if is_root:
 
     model_config = dict(
         depth=DEPTH,
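model_config collects the transformer hyperparameters so the experiment tracker can record them. For reference, such a dict is typically handed to wandb.init as the run config (the project name below is a placeholder, not necessarily this script's):

import wandb

# placeholder project name; the real script derives its run settings from CLI args
run = wandb.init(
    project='dalle-train',
    config=model_config,  # hyperparameters stored alongside the run
)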
@@ -529,14 +532,14 @@ def save_model(path, epoch=0):
     if using_deepspeed:
         cp_dir = cp_path_to_dir(path, 'ds')
 
-        if KEEP_N_CHECKPOINTS is not None and distr_backend.is_root_worker():
+        if KEEP_N_CHECKPOINTS is not None and is_root:
             checkpoints = sorted(glob(str(cp_dir / "global*")), key=os.path.getmtime, reverse=True)
             for checkpoint in checkpoints[KEEP_N_CHECKPOINTS:]:
                 shutil.rmtree(checkpoint)
 
         distr_dalle.save_checkpoint(cp_dir, client_state=save_obj)
 
-        if not distr_backend.is_root_worker():
+        if not is_root:
             return
 
         # Save auxiliary values so we can reuse the standard routine
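save_checkpoint here is DeepSpeed's engine method: each call writes a global_step* directory under cp_dir (hence the global* glob the rotation logic prunes) and persists client_state for retrieval at load time. A hedged sketch of the matching load side (the directory path is illustrative):

# DeepSpeed pairs save_checkpoint with load_checkpoint; client_state round-trips,
# so values like the epoch counter can be recovered on resume
load_path, client_state = distr_dalle.load_checkpoint('checkpoints/ds')
resume_epoch = client_state['epoch'] if client_state else 0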
@@ -554,7 +557,7 @@ def save_model(path, epoch=0):
         if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2:  # see https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
             return
 
-    if not distr_backend.is_root_worker():
+    if not is_root:
         return
 
     save_obj = {
@@ -566,19 +569,29 @@ def save_model(path, epoch=0):
 
     torch.save(save_obj, path)
 
+def save_artifact(model_config, model_path, name='trained-dalle'):
+    model_artifact = wandb.Artifact(name, type='model', metadata=dict(model_config))
+    model_artifact.add_file(model_path)
+    run.log_artifact(model_artifact)
+
 # training
 
 # Saves a checkpoint before training begins to fail early when mis-configured.
 # See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
+
 save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
+
 for epoch in range(resume_epoch, EPOCHS):
     if data_sampler:
         data_sampler.set_epoch(epoch)
+
     for i, (text, images) in enumerate((dl if ENABLE_WEBDATASET else distr_dl)):
-        if i % 10 == 0 and distr_backend.is_root_worker():
+        if i % 10 == 0 and is_root:
             t = time.time()
+
         if args.fp16:
             images = images.half()
+
         text, images = map(lambda t: t.cuda(), (text, images))
 
         loss = distr_dalle(text, images, return_loss=True)
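The save_artifact helper added in this hunk consolidates the two copies of artifact-logging boilerplate that appear further down. One property of the underlying wandb API worth noting: logging an Artifact under an existing name creates a new version (v0, v1, ...) rather than overwriting, so calling the helper every epoch yields a browsable history. A usage sketch (file paths are illustrative):

# each call logs a new version of the 'trained-dalle' artifact
save_artifact(model_config, './dalle.pt')
save_artifact(model_config, './dalle-final.pt', name='trained-dalle-final')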
@@ -598,7 +611,7 @@ def save_model(path, epoch=0):
 
         log = {}
 
-        if i % 10 == 0 and distr_backend.is_root_worker():
+        if i % 10 == 0 and is_root:
             print(epoch, i, f'loss - {avg_loss.item()}')
 
             log = {
@@ -611,47 +624,41 @@ def save_model(path, epoch=0):
         if i % SAVE_EVERY_N_STEPS == 0:
             save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
 
-        if i % 100 == 0:
-            if distr_backend.is_root_worker():
-                sample_text = text[:1]
-                token_list = sample_text.masked_select(sample_text != 0).tolist()
-                decoded_text = tokenizer.decode(token_list)
+        if i % 100 == 0 and is_root:
+            sample_text = text[:1]
+            token_list = sample_text.masked_select(sample_text != 0).tolist()
+            decoded_text = tokenizer.decode(token_list)
 
-            if not avoid_model_calls:
-                # CUDA index errors when we don't guard this
-                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
+            if not avoid_model_calls:
+                # CUDA index errors when we don't guard this
+                image = dalle.generate_images(text[:1], filter_thres=0.9)  # topk sampling at 0.9
 
-            if not avoid_model_calls:
-                log['image'] = wandb.Image(image, caption=decoded_text)
+            if not avoid_model_calls:
+                log['image'] = wandb.Image(image, caption=decoded_text)
 
-        if i % 10 == 9 and distr_backend.is_root_worker():
+        if i % 10 == 9 and is_root:
             sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
             log["sample_per_sec"] = sample_per_sec
             print(epoch, i, f'sample_per_sec - {sample_per_sec}')
 
         if i == 201 and args.flops_profiler:
             raise StopIteration("Profiler has finished running. Stopping training early.")
 
-        if distr_backend.is_root_worker():
+        if is_root:
             wandb.log(log)
 
     if LR_DECAY:
         distr_scheduler.step(avg_loss)
 
     save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
 
-    if distr_backend.is_root_worker():
+    if is_root:
         # save trained model to wandb as an artifact every epoch's end
-
-        model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
-        model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
-        run.log_artifact(model_artifact)
+        save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)
 
 save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-if distr_backend.is_root_worker():
-    wandb.save(DALLE_OUTPUT_FILE_NAME)
-    model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
-    model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
-    run.log_artifact(model_artifact)
 
+if is_root:
+    wandb.save(DALLE_OUTPUT_FILE_NAME)
+    save_artifact(model_config, DALLE_OUTPUT_FILE_NAME)
     wandb.finish()
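filter_thres=0.9 in the generate_images call above selects top-k filtering of the logits before sampling; in the convention these repos use, the threshold keeps roughly the top (1 - thres) fraction of the vocabulary. A paraphrased sketch of that filtering step, not the repo's exact code:

import torch

def top_k(logits, thres=0.9):
    # keep the top (1 - thres) fraction of logits, i.e. the top 10% at thres=0.9
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = torch.topk(logits, k)
    filtered = torch.full_like(logits, float('-inf'))
    filtered.scatter_(-1, ind, val)
    return filtered

# masked positions get zero probability after the softmax
probs = torch.softmax(top_k(torch.randn(1, 8192)), dim=-1)
sample = torch.multinomial(probs, 1)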