Commit 1c25f54

start storing the version of dalle-pytorch alongside the model weights - also throw an error when trying to generate with a DALL-E whose VAE is not of the same type as the one it was trained with
1 parent e1d10b9 commit 1c25f54

File tree

4 files changed: +51 -30 lines changed

dalle_pytorch/__init__.py

+3

@@ -1,2 +1,5 @@
 from dalle_pytorch.dalle_pytorch import DALLE, CLIP, DiscreteVAE
 from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
+
+from pkg_resources import get_distribution
+__version__ = get_distribution('dalle_pytorch').version
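
For reference, a minimal sketch of what the attribute added above exposes once the package is installed (output value illustrative):

import dalle_pytorch

# __version__ is resolved at import time from the installed distribution
# metadata, so it reflects whatever version pip actually installed
print(dalle_pytorch.__version__)  # e.g. '1.1.6'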

generate.py

+14 -2

@@ -79,9 +79,16 @@ def exists(val):
 assert dalle_path.exists(), 'trained DALL-E must exist'
 
 load_obj = torch.load(str(dalle_path))
-dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
+dalle_params, vae_params, weights, vae_class_name, version = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights'), load_obj.pop('vae_class_name', None), load_obj.pop('version', None)
 
-dalle_params.pop('vae', None) # cleanup later
+# friendly print
+
+if exists(version):
+    print(f'Loading a model trained with DALLE-pytorch version {version}')
+else:
+    print('You are loading a model trained on an older version of DALL-E pytorch - it may not be compatible with the most recent version')
+
+# load VAE
 
 if args.taming:
     vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)

@@ -90,6 +97,10 @@ def exists(val):
 else:
     vae = OpenAIDiscreteVAE()
 
+assert not (exists(vae_class_name) and vae.__class__.__name__ != vae_class_name), f'you trained DALL-E using {vae_class_name} but are trying to generate with {vae.__class__.__name__} - please make sure you are passing in the correct paths and settings for the VAE to use for generation'
+
+# reconstitute DALL-E
+
 dalle = DALLE(vae = vae, **dalle_params).cuda()
 
 dalle.load_state_dict(weights)

@@ -118,6 +129,7 @@ def exists(val):
 outputs = torch.cat(outputs)
 
 # save all images
+
 file_name = text
 outputs_dir = Path(args.outputs_dir) / file_name.replace(' ', '_')[:(100)]
 outputs_dir.mkdir(parents = True, exist_ok = True)
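
Taken together, the generate.py changes keep old checkpoints loadable (pop() with a None default) while new checkpoints carry enough metadata to fail fast on a VAE mismatch. A minimal sketch of that guard in isolation (the function name and path are hypothetical, not part of the commit):

import torch

def check_vae_matches_checkpoint(vae, checkpoint_path):
    # checkpoints saved before this commit lack 'vae_class_name',
    # so the None default simply skips the check for them
    expected = torch.load(checkpoint_path).pop('vae_class_name', None)
    actual = vae.__class__.__name__
    assert expected is None or expected == actual, \
        f'checkpoint was trained with {expected} but generating with {actual}'

For example, a DALL-E trained with VQGanVAE but run through generate.py without --taming now stops with a clear error instead of sampling through the wrong OpenAIDiscreteVAE.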

setup.py

+1 -1

@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.1.5',
+  version = '1.1.6',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',

train_dalle.py

+33 -27

@@ -45,12 +45,8 @@
 parser.add_argument('--image_text_folder', type=str, required=True,
                     help='path to your folder of images and text for learning the DALL-E')
 
-parser.add_argument(
-    '--wds',
-    type = str,
-    default='',
-    help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
-)
+parser.add_argument('--wds', type = str, default='',
+                    help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.')
 
 parser.add_argument('--truncate_captions', dest='truncate_captions', action='store_true',
                     help='Captions passed in which exceed the max token length will be truncated if this is set.')

@@ -75,7 +71,7 @@
 
 
 parser.add_argument('--amp', action='store_true',
-	help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')
+                    help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')
 
 parser.add_argument('--wandb_name', default='dalle_train_transformer',
                     help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')

@@ -144,6 +140,10 @@ def exists(val):
 def get_trainable_params(model):
     return [params for params in model.parameters() if params.requires_grad]
 
+def get_pkg_version():
+    from pkg_resources import get_distribution
+    return get_distribution('dalle_pytorch').version
+
 def cp_path_to_dir(cp_path, tag):
     """Convert a checkpoint path to a directory with `tag` inserted.
     If `cp_path` is already a directory, return it unchanged.

@@ -157,6 +157,7 @@ def cp_path_to_dir(cp_path, tag):
     return cp_dir
 
 # constants
+
 WEBDATASET_IMAGE_TEXT_COLUMNS = tuple(args.wds.split(','))
 ENABLE_WEBDATASET = True if len(WEBDATASET_IMAGE_TEXT_COLUMNS) == 2 else False
 

@@ -232,6 +233,7 @@ def cp_path_to_dir(cp_path, tag):
     tokenizer = ChineseTokenizer()
 
 # reconstitute vae
+
 if RESUME:
     dalle_path = Path(DALLE_PATH)
     if using_deepspeed:

@@ -249,15 +251,11 @@ def cp_path_to_dir(cp_path, tag):
 
     if vae_params is not None:
         vae = DiscreteVAE(**vae_params)
+    elif args.taming:
+        vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
     else:
-        if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
-        else:
-            vae = OpenAIDiscreteVAE()
+        vae = OpenAIDiscreteVAE()
 
-    dalle_params = dict(
-        **dalle_params
-    )
     IMAGE_SIZE = vae.image_size
     resume_epoch = loaded_obj.get('epoch', 0)
 else:

@@ -311,7 +309,6 @@ def cp_path_to_dir(cp_path, tag):
 if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
     vae.enc.blocks.output.conv.use_float16 = True
 
-
 # helpers
 
 def group_weight(model):

@@ -388,17 +385,20 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 if not ENABLE_WEBDATASET:
     print(f'{len(ds)} image-text pairs found for training')
 
+# data sampler
+
+data_sampler = None
+
 if not is_shuffle:
     data_sampler = torch.utils.data.distributed.DistributedSampler(
         ds,
         num_replicas=distr_backend.get_world_size(),
         rank=distr_backend.get_rank()
     )
-else:
-    data_sampler = None
+
+# WebLoader for WebDataset and DeepSpeed compatibility
 
 if ENABLE_WEBDATASET:
-    # WebLoader for WebDataset and DeepSpeed compatibility
     dl = wds.WebLoader(ds, batch_size=None, shuffle=False, num_workers=4) # optionally add num_workers=2 (n) argument
     number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
     dl = dl.slice(number_of_batches)

@@ -407,10 +407,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
     # Regular DataLoader for image-text-folder datasets
     dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)
 
-
 # initialize DALL-E
 
 dalle = DALLE(vae=vae, **dalle_params)
+
 if not using_deepspeed:
     if args.fp16:
         dalle = dalle.half()

@@ -422,9 +422,14 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
 # optimizer
 
 opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
+
 if RESUME and opt_state:
     opt.load_state_dict(opt_state)
 
+# scheduler
+
+scheduler = None
+
 if LR_DECAY:
     scheduler = ReduceLROnPlateau(
         opt,

@@ -437,11 +442,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
     )
     if RESUME and scheduler_state:
         scheduler.load_state_dict(scheduler_state)
-else:
-    scheduler = None
+
+# experiment tracker
 
 if distr_backend.is_root_worker():
-    # experiment tracker
 
     model_config = dict(
         depth=DEPTH,

@@ -503,8 +507,10 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
     config_params=deepspeed_config,
 )
 # Prefer scheduler in `deepspeed_config`.
+
 if LR_DECAY and distr_scheduler is None:
     distr_scheduler = scheduler
+
 avoid_model_calls = using_deepspeed and args.fp16
 
 if RESUME and using_deepspeed:

@@ -516,7 +522,10 @@ def save_model(path, epoch=0):
         'hparams': dalle_params,
         'vae_params': vae_params,
         'epoch': epoch,
+        'version': get_pkg_version(),
+        'vae_class_name': vae.__class__.__name__
     }
+
     if using_deepspeed:
         cp_dir = cp_path_to_dir(path, 'ds')
 

@@ -552,8 +561,9 @@ def save_model(path, epoch=0):
         **save_obj,
         'weights': dalle.state_dict(),
         'opt_state': opt.state_dict(),
+        'scheduler_state': (scheduler.state_dict() if scheduler else None)
     }
-    save_obj['scheduler_state'] = (scheduler.state_dict() if scheduler else None)
+
     torch.save(save_obj, path)
 
 # training

@@ -611,10 +621,6 @@ def save_model(path, epoch=0):
         # CUDA index errors when we don't guard this
         image = dalle.generate_images(text[:1], filter_thres=0.9) # topk sampling at 0.9
 
-
-        log = {
-            **log,
-        }
         if not avoid_model_calls:
             log['image'] = wandb.Image(image, caption=decoded_text)
 
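
Consolidating the save_model() hunks above, a checkpoint written after this commit carries two new metadata keys alongside the existing ones. A sketch of inspecting one (the path is illustrative):

import torch

ckpt = torch.load('dalle.pt')
print(sorted(ckpt.keys()))
# ['epoch', 'hparams', 'opt_state', 'scheduler_state',
#  'vae_class_name', 'vae_params', 'version', 'weights']
#
# 'version'        -> get_pkg_version() at save time, e.g. '1.1.6'
# 'vae_class_name' -> e.g. 'VQGanVAE' or 'OpenAIDiscreteVAE'

generate.py reads the two new keys back with pop(key, None), which is why checkpoints saved before 1.1.6 still load, just without the version print and the VAE-type check.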
