Commit 459c46a

add super-conditioning feature, idea originated from @crowsonkb
1 parent 4511d2d · commit 459c46a

File tree

3 files changed (+96 −4 lines)

README.md (+74)

@@ -200,6 +200,63 @@ vae = VQGanVAE()
 
 The default VQGan is the codebook size 1024 one trained on ImageNet. If you wish to use a different one, you can use the `vqgan_model_path` and `vqgan_config_path` to pass the .ckpt file and the .yaml file. These options can be used either in the train-dalle script or as arguments to the VQGanVAE class. Other pretrained VQGANs can be found in the [taming transformers readme](https://github.com/CompVis/taming-transformers#overview-of-pretrained-models). If you want to train a custom one, you can [follow this guide](https://github.com/CompVis/taming-transformers/pull/54)
 
+## Adjust text conditioning strength
+
+A <a href="https://openreview.net/forum?id=qw8AKxfYbI">new technique</a> has recently surfaced for guiding diffusion models without a classifier. The gist of the technique is to randomly drop out the text condition during training, and at inference time to derive a rough direction from the unconditional to the conditional distribution.
+
+<a href="https://github.com/crowsonkb">Katherine Crowson</a> outlined in a <a href="https://twitter.com/RiversHaveWings/status/1478093658716966912">tweet</a> how this could work for autoregressive attention models. I have decided to include her idea in this repository for further exploration. One only has to account for two extra keyword arguments, one at training time (`null_cond_prob`) and one at generation time (`cond_scale`).
+
+```python
+import torch
+from dalle_pytorch import DiscreteVAE, DALLE
+
+vae = DiscreteVAE(
+    image_size = 256,
+    num_layers = 3,
+    num_tokens = 8192,
+    codebook_dim = 1024,
+    hidden_dim = 64,
+    num_resnet_blocks = 1,
+    temperature = 0.9
+)
+
+dalle = DALLE(
+    dim = 1024,
+    vae = vae,
+    num_text_tokens = 10000,
+    text_seq_len = 256,
+    depth = 12,
+    heads = 16,
+    dim_head = 64,
+    attn_dropout = 0.1,
+    ff_dropout = 0.1
+)
+
+text = torch.randint(0, 10000, (4, 256))
+images = torch.randn(4, 3, 256, 256)
+
+loss = dalle(
+    text,
+    images,
+    return_loss = True,
+    null_cond_prob = 0.2  # firstly, set this to the probability of dropping out the condition, 20% is recommended as a default
+)
+
+loss.backward()
+
+# do the above for a long time with a lot of data ... then
+
+images = dalle.generate_images(
+    text,
+    cond_scale = 3.  # secondly, set this to a value greater than 1 to increase the conditioning beyond average
+)
+
+images.shape # (4, 3, 256, 256)
+```
+
+That's it!
+
 ## Ranking the generations
 
 Train CLIP
@@ -673,4 +730,21 @@ $ python generate.py --chinese --text '追老鼠的猫'
 }
 ```
 
+```bibtex
+@inproceedings{ho2021classifierfree,
+    title     = {Classifier-Free Diffusion Guidance},
+    author    = {Jonathan Ho and Tim Salimans},
+    booktitle = {NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications},
+    year      = {2021},
+    url       = {https://openreview.net/forum?id=qw8AKxfYbI}
+}
+```
+
+```bibtex
+@misc{crowson2022,
+    author = {Katherine Crowson},
+    url    = {https://twitter.com/RiversHaveWings/status/1478093658716966912}
+}
+```
+
 *Those who do not want to imitate anything, produce nothing.* - Dali
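To make the mechanics concrete before the implementation diff below: at each decoding step the model is run twice, once with the text condition and once with it nulled out, and the two sets of logits are linearly extrapolated. Here is a minimal sketch of just that combination step, with random tensors standing in for the two forward passes (the batch size and vocabulary size are illustrative):

```python
import torch

# stand-ins for one decoding step's logits; in the model these come from
# two forward passes, one with the text intact and one with the text dropped
cond_logits      = torch.randn(4, 8192)   # conditioned on the text
null_cond_logits = torch.randn(4, 8192)   # null (unconditional) pass

cond_scale = 3.

# cond_scale = 1 recovers the plain conditional logits; cond_scale > 1
# extrapolates past them, away from the unconditional distribution
logits = null_cond_logits + (cond_logits - null_cond_logits) * cond_scale
```

Since the null pass only runs when `cond_scale != 1`, super-conditioned sampling costs roughly two forward passes per generated token instead of one.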

dalle_pytorch/dalle_pytorch.py (+21 −3)

@@ -32,6 +32,9 @@ def masked_mean(t, mask, dim = 1):
     t = t.masked_fill(~mask[:, :, None], 0.)
     return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
 
+def prob_mask_like(shape, prob, device):
+    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
 def set_requires_grad(model, value):
     for param in model.parameters():
         param.requires_grad = value
@@ -469,7 +472,8 @@ def generate_images(
         filter_thres = 0.5,
         temperature = 1.,
         img = None,
-        num_init_img_tokens = None
+        num_init_img_tokens = None,
+        cond_scale = 1.
     ):
         vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
         total_len = text_seq_len + image_seq_len
@@ -494,6 +498,13 @@
             text, image = out[:, :text_seq_len], out[:, text_seq_len:]
 
             logits = self(text, image)
+
+            if cond_scale != 1:
+                # discovery by Katherine Crowson
+                # https://twitter.com/RiversHaveWings/status/1478093658716966912
+                null_cond_logits = self(text, image, null_cond_prob = 1.)
+                logits = null_cond_logits + (logits - null_cond_logits) * cond_scale
+
             logits = logits[:, -1, :]
 
             filtered_logits = top_k(logits, thres = filter_thres)
@@ -517,10 +528,17 @@ def forward(
         self,
         text,
         image = None,
-        return_loss = False
+        return_loss = False,
+        null_cond_prob = 0.
     ):
         assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
-        device, total_seq_len = text.device, self.total_seq_len
+        batch, device, total_seq_len = text.shape[0], text.device, self.total_seq_len
+
+        # randomly remove text condition with <null_cond_prob> probability
+
+        if null_cond_prob > 0:
+            null_mask = prob_mask_like((batch,), null_cond_prob, device = device)
+            text *= rearrange(~null_mask, 'b -> b 1')
 
         # make sure padding in text tokens get unique padding token id
 
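For the training-time half of the change, here is the condition dropout from the diff above in isolation, as a self-contained sketch; the batch size, vocabulary size, and sequence length are illustrative, and `einops.rearrange` is used just as in the repository:

```python
import torch
from einops import rearrange

def prob_mask_like(shape, prob, device):
    # boolean mask where each entry is True with probability `prob`
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

batch, text_seq_len = 4, 256                           # illustrative sizes
text = torch.randint(1, 10000, (batch, text_seq_len))  # toy text token ids

# with probability 0.2, zero out a sample's entire text sequence, so the
# model periodically trains without any text condition at all
null_mask = prob_mask_like((batch,), 0.2, device = text.device)
text = text * rearrange(~null_mask, 'b -> b 1')
```

Zeroed text tokens act as the null condition here, which is what lets `generate_images` later request a fully unconditional pass with `null_cond_prob = 1.`.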
setup.py (+1 −1)

@@ -4,7 +4,7 @@
     name = 'dalle-pytorch',
     packages = find_packages(),
     include_package_data = True,
-    version = '1.2.0',
+    version = '1.2.1',
     license='MIT',
     description = 'DALL-E - Pytorch',
     author = 'Phil Wang',
