Skip to content

Commit e68a30f

Browse files
committed
add a new image_size parameter in train_dalle and generate
VAE models can be used with patches of any size. For example, a model trained on 16x16 patches can still be used on 32x32 patches, which increases the seq length from 256 to 1024 in dalle
1 parent 01e402e commit e68a30f

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

Diff for: dalle_pytorch/vae.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ def download(url, filename = None, root = CACHE_PATH):
9696
# pretrained Discrete VAE from OpenAI
9797

9898
class OpenAIDiscreteVAE(nn.Module):
99-
def __init__(self):
99+
def __init__(self, image_size=256):
100100
super().__init__()
101101

102102
self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
103103
self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
104104

105105
self.num_layers = 3
106-
self.image_size = 256
106+
self.image_size = image_size
107107
self.num_tokens = 8192
108108

109109
@torch.no_grad()
@@ -142,7 +142,7 @@ def instantiate_from_config(config):
142142
return get_obj_from_str(config["target"])(**config.get("params", dict()))
143143

144144
class VQGanVAE(nn.Module):
145-
def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
145+
def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
146146
super().__init__()
147147

148148
if vqgan_model_path is None:
@@ -170,7 +170,7 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
170170
# f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
171171
f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
172172
self.num_layers = int(log(f)/log(2))
173-
self.image_size = 256
173+
self.image_size = image_size
174174
self.num_tokens = config.model.params.n_embed
175175
self.is_gumbel = isinstance(self.model, GumbelVQ)
176176

Diff for: generate.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
4343
parser.add_argument('--top_k', type = float, default = 0.9, required = False,
4444
help='top k filter threshold')
4545

46+
parser.add_argument('--image_size', type = int, default = 256, required = False,
47+
help='image size')
48+
4649
parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
4750
help='output directory')
4851

@@ -81,12 +84,14 @@ def exists(val):
8184

8285
dalle_params.pop('vae', None) # cleanup later
8386

87+
IMAGE_SIZE = args.image_size
88+
8489
if vae_params is not None:
85-
vae = DiscreteVAE(**vae_params)
90+
vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
8691
elif not args.taming:
87-
vae = OpenAIDiscreteVAE()
92+
vae = OpenAIDiscreteVAE(IMAGE_SIZE)
8893
else:
89-
vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
94+
vae = VQGanVAE(IMAGE_SIZE, args.vqgan_model_path, args.vqgan_config_path)
9095

9196

9297
dalle = DALLE(vae = vae, **dalle_params).cuda()
@@ -95,8 +100,6 @@ def exists(val):
95100

96101
# generate images
97102

98-
image_size = vae.image_size
99-
100103
texts = args.text.split('|')
101104

102105
for text in tqdm(texts):

Diff for: train_dalle.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@
128128

129129
model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
130130

131+
model_group.add_argument('--image_size', default = 256, type = int, help = 'Image size')
132+
131133
model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')
132134

133135
args = parser.parse_args()
@@ -173,6 +175,7 @@ def cp_path_to_dir(cp_path, tag):
173175
SAVE_EVERY_N_STEPS = args.save_every_n_steps
174176
KEEP_N_CHECKPOINTS = args.keep_n_checkpoints
175177

178+
IMAGE_SIZE = args.image_size
176179
MODEL_DIM = args.dim
177180
TEXT_SEQ_LEN = args.text_seq_len
178181
DEPTH = args.depth
@@ -242,17 +245,16 @@ def cp_path_to_dir(cp_path, tag):
242245
scheduler_state = loaded_obj.get('scheduler_state')
243246

244247
if vae_params is not None:
245-
vae = DiscreteVAE(**vae_params)
248+
vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
246249
else:
247250
if args.taming:
248-
vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
251+
vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
249252
else:
250-
vae = OpenAIDiscreteVAE()
253+
vae = OpenAIDiscreteVAE(IMAGE_SIZE)
251254

252255
dalle_params = dict(
253256
**dalle_params
254257
)
255-
IMAGE_SIZE = vae.image_size
256258
resume_epoch = loaded_obj.get('epoch', 0)
257259
else:
258260
if exists(VAE_PATH):
@@ -268,19 +270,17 @@ def cp_path_to_dir(cp_path, tag):
268270

269271
vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
270272

271-
vae = DiscreteVAE(**vae_params)
273+
vae = DiscreteVAE(IMAGE_SIZE, **vae_params[1:])
272274
vae.load_state_dict(weights)
273275
else:
274276
if distr_backend.is_root_worker():
275277
print('using pretrained VAE for encoding images to tokens')
276278
vae_params = None
277279

278280
if args.taming:
279-
vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
281+
vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
280282
else:
281-
vae = OpenAIDiscreteVAE()
282-
283-
IMAGE_SIZE = vae.image_size
283+
vae = OpenAIDiscreteVAE(IMAGE_SIZE)
284284

285285
dalle_params = dict(
286286
num_text_tokens=tokenizer.vocab_size,

0 commit comments

Comments
 (0)