45
45
parser .add_argument ('--image_text_folder' , type = str , required = True ,
46
46
help = 'path to your folder of images and text for learning the DALL-E' )
47
47
48
- parser .add_argument (
49
- '--wds' ,
50
- type = str ,
51
- default = '' ,
52
- help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
53
- )
48
+ parser .add_argument ('--wds' , type = str , default = '' ,
49
+ help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.' )
54
50
55
51
parser .add_argument ('--truncate_captions' , dest = 'truncate_captions' , action = 'store_true' ,
56
52
help = 'Captions passed in which exceed the max token length will be truncated if this is set.' )
75
71
76
72
77
73
parser .add_argument ('--amp' , action = 'store_true' ,
78
- help = 'Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\' t be used in conjunction with deepspeed zero stages 1-3.' )
74
+ help = 'Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\' t be used in conjunction with deepspeed zero stages 1-3.' )
79
75
80
76
parser .add_argument ('--wandb_name' , default = 'dalle_train_transformer' ,
81
77
help = 'Name W&B will use when saving results.\n e.g. `--wandb_name "coco2017-full-sparse"`' )
@@ -144,6 +140,10 @@ def exists(val):
144
140
def get_trainable_params(model):
    """Return the list of *model* parameters that require gradients.

    Frozen parameters (``requires_grad == False``) are excluded so the
    optimizer only receives weights that will actually be updated.
    """
    return list(filter(lambda p: p.requires_grad, model.parameters()))
146
142
143
def get_pkg_version(pkg='dalle_pytorch'):
    """Return the installed version string of *pkg*.

    Defaults to ``'dalle_pytorch'`` so existing zero-argument callers are
    unchanged; the parameter generalizes the helper to any installed
    distribution.

    Uses :func:`importlib.metadata.version` (stdlib, Python 3.8+) instead
    of the deprecated ``pkg_resources.get_distribution`` API, which is
    slow to import and slated for removal from setuptools.

    Raises ``importlib.metadata.PackageNotFoundError`` if *pkg* is not
    installed.
    """
    # Local import keeps module import time low, matching the original's
    # function-scope import style.
    from importlib.metadata import version
    return version(pkg)
147
147
def cp_path_to_dir (cp_path , tag ):
148
148
"""Convert a checkpoint path to a directory with `tag` inserted.
149
149
If `cp_path` is already a directory, return it unchanged.
@@ -157,6 +157,7 @@ def cp_path_to_dir(cp_path, tag):
157
157
return cp_dir
158
158
159
159
# constants
160
+
160
161
WEBDATASET_IMAGE_TEXT_COLUMNS = tuple (args .wds .split (',' ))
161
162
ENABLE_WEBDATASET = True if len (WEBDATASET_IMAGE_TEXT_COLUMNS ) == 2 else False
162
163
@@ -232,6 +233,7 @@ def cp_path_to_dir(cp_path, tag):
232
233
tokenizer = ChineseTokenizer ()
233
234
234
235
# reconstitute vae
236
+
235
237
if RESUME :
236
238
dalle_path = Path (DALLE_PATH )
237
239
if using_deepspeed :
@@ -249,15 +251,11 @@ def cp_path_to_dir(cp_path, tag):
249
251
250
252
if vae_params is not None :
251
253
vae = DiscreteVAE (** vae_params )
254
+ elif args .taming :
255
+ vae = VQGanVAE (VQGAN_MODEL_PATH , VQGAN_CONFIG_PATH )
252
256
else :
253
- if args .taming :
254
- vae = VQGanVAE (VQGAN_MODEL_PATH , VQGAN_CONFIG_PATH )
255
- else :
256
- vae = OpenAIDiscreteVAE ()
257
+ vae = OpenAIDiscreteVAE ()
257
258
258
- dalle_params = dict (
259
- ** dalle_params
260
- )
261
259
IMAGE_SIZE = vae .image_size
262
260
resume_epoch = loaded_obj .get ('epoch' , 0 )
263
261
else :
@@ -311,7 +309,6 @@ def cp_path_to_dir(cp_path, tag):
311
309
if isinstance (vae , OpenAIDiscreteVAE ) and args .fp16 :
312
310
vae .enc .blocks .output .conv .use_float16 = True
313
311
314
-
315
312
# helpers
316
313
317
314
def group_weight (model ):
@@ -388,17 +385,20 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
388
385
if not ENABLE_WEBDATASET :
389
386
print (f'{ len (ds )} image-text pairs found for training' )
390
387
388
+ # data sampler
389
+
390
+ data_sampler = None
391
+
391
392
if not is_shuffle :
392
393
data_sampler = torch .utils .data .distributed .DistributedSampler (
393
394
ds ,
394
395
num_replicas = distr_backend .get_world_size (),
395
396
rank = distr_backend .get_rank ()
396
397
)
397
- else :
398
- data_sampler = None
398
+
399
+ # WebLoader for WebDataset and DeepSpeed compatibility
399
400
400
401
if ENABLE_WEBDATASET :
401
- # WebLoader for WebDataset and DeepSpeed compatibility
402
402
dl = wds .WebLoader (ds , batch_size = None , shuffle = False , num_workers = 4 ) # optionally add num_workers=2 (n) argument
403
403
number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend .get_world_size ())
404
404
dl = dl .slice (number_of_batches )
@@ -407,10 +407,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
407
407
# Regular DataLoader for image-text-folder datasets
408
408
dl = DataLoader (ds , batch_size = BATCH_SIZE , shuffle = is_shuffle , drop_last = True , sampler = data_sampler )
409
409
410
-
411
410
# initialize DALL-E
412
411
413
412
dalle = DALLE (vae = vae , ** dalle_params )
413
+
414
414
if not using_deepspeed :
415
415
if args .fp16 :
416
416
dalle = dalle .half ()
@@ -422,9 +422,14 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
422
422
# optimizer
423
423
424
424
opt = Adam (get_trainable_params (dalle ), lr = LEARNING_RATE )
425
+
425
426
if RESUME and opt_state :
426
427
opt .load_state_dict (opt_state )
427
428
429
+ # scheduler
430
+
431
+ scheduler = None
432
+
428
433
if LR_DECAY :
429
434
scheduler = ReduceLROnPlateau (
430
435
opt ,
@@ -437,11 +442,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
437
442
)
438
443
if RESUME and scheduler_state :
439
444
scheduler .load_state_dict (scheduler_state )
440
- else :
441
- scheduler = None
445
+
446
+ # experiment tracker
442
447
443
448
if distr_backend .is_root_worker ():
444
- # experiment tracker
445
449
446
450
model_config = dict (
447
451
depth = DEPTH ,
@@ -503,8 +507,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
503
507
config_params = deepspeed_config ,
504
508
)
505
509
# Prefer scheduler in `deepspeed_config`.
510
+
506
511
if LR_DECAY and distr_scheduler is None :
507
512
distr_scheduler = scheduler
513
+
508
514
avoid_model_calls = using_deepspeed and args .fp16
509
515
510
516
if RESUME and using_deepspeed :
@@ -516,7 +522,10 @@ def save_model(path, epoch=0):
516
522
'hparams' : dalle_params ,
517
523
'vae_params' : vae_params ,
518
524
'epoch' : epoch ,
525
+ 'version' : get_pkg_version (),
526
+ 'vae_class_name' : vae .__class__ .__name__
519
527
}
528
+
520
529
if using_deepspeed :
521
530
cp_dir = cp_path_to_dir (path , 'ds' )
522
531
@@ -552,8 +561,9 @@ def save_model(path, epoch=0):
552
561
** save_obj ,
553
562
'weights' : dalle .state_dict (),
554
563
'opt_state' : opt .state_dict (),
564
+ 'scheduler_state' : (scheduler .state_dict () if scheduler else None )
555
565
}
556
- save_obj [ 'scheduler_state' ] = ( scheduler . state_dict () if scheduler else None )
566
+
557
567
torch .save (save_obj , path )
558
568
559
569
# training
@@ -611,10 +621,6 @@ def save_model(path, epoch=0):
611
621
# CUDA index errors when we don't guard this
612
622
image = dalle .generate_images (text [:1 ], filter_thres = 0.9 ) # topk sampling at 0.9
613
623
614
-
615
- log = {
616
- ** log ,
617
- }
618
624
if not avoid_model_calls :
619
625
log ['image' ] = wandb .Image (image , caption = decoded_text )
620
626
0 commit comments