Skip to content

Commit bd512ad

Browse files
committed
fix everything and make sure it runs end to end, document everything in readme for public
1 parent e5e4152 commit bd512ad

File tree

4 files changed

+365
-73
lines changed

4 files changed

+365
-73
lines changed

README.md

+277-7
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,297 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
2222
$ pip install dalle2-pytorch
2323
```
2424

25-
## Usage (work in progress)
26-
27-
<a href="https://github.com/lucidrains/big-sleep">template</a>
25+
## CLI Usage (work in progress)
2826

2927
```bash
3028
$ dream 'sharing a sunset at the summit of mount everest with my dog'
3129
```
3230

3331
Once built, images will be saved to the same directory the command is invoked
3432

35-
## Training (work in progress, will offer both in code and as command-line)
33+
## Training (for deep learning practitioners)
3634

37-
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
35+
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
36+
37+
To train CLIP, you can either use `x-clip` package, or join the LAION discord, where a lot of replication efforts are already underway.
38+
39+
This repository will demonstrate integration with `x-clip` for starters
40+
41+
```python
42+
import torch
43+
from dalle2_pytorch import CLIP
44+
45+
clip = CLIP(
46+
dim_text = 512,
47+
dim_image = 512,
48+
dim_latent = 512,
49+
num_text_tokens = 49408,
50+
text_enc_depth = 1,
51+
text_seq_len = 256,
52+
text_heads = 8,
53+
visual_enc_depth = 1,
54+
visual_image_size = 256,
55+
visual_patch_size = 32,
56+
visual_heads = 8,
57+
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
58+
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
59+
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
60+
use_visual_ssl = True, # whether to do self supervised learning on iages
61+
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
62+
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
63+
text_ssl_loss_weight = 0.05, # weight for text MLM loss
64+
image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss
65+
).cuda()
66+
67+
# mock data
68+
69+
text = torch.randint(0, 49408, (4, 256)).cuda()
70+
images = torch.randn(4, 3, 256, 256).cuda()
71+
72+
# train
73+
74+
loss = clip(
75+
text,
76+
images,
77+
return_loss = True # needs to be set to True to return contrastive loss
78+
)
79+
80+
loss.backward()
81+
82+
# do the above with as many texts and images as possible in a loop
83+
```
84+
85+
Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above
86+
87+
```python
88+
import torch
89+
from dalle2_pytorch import Unet, Decoder, CLIP
90+
91+
# trained clip from step 1
92+
93+
clip = CLIP(
94+
dim_text = 512,
95+
dim_image = 512,
96+
dim_latent = 512,
97+
num_text_tokens = 49408,
98+
text_enc_depth = 1,
99+
text_seq_len = 256,
100+
text_heads = 8,
101+
visual_enc_depth = 1,
102+
visual_image_size = 256,
103+
visual_patch_size = 32,
104+
visual_heads = 8
105+
).cuda()
106+
107+
# unet for the decoder
108+
109+
unet = Unet(
110+
dim = 128,
111+
image_embed_dim = 512,
112+
time_dim = 128,
113+
channels = 3,
114+
dim_mults=(1, 2, 4, 8)
115+
).cuda()
116+
117+
# decoder, which contains the unet and clip
118+
119+
decoder = Decoder(
120+
net = unet,
121+
clip = clip,
122+
timesteps = 100,
123+
cond_drop_prob = 0.2
124+
).cuda()
125+
126+
# mock images (get a lot of this)
127+
128+
images = torch.randn(4, 3, 256, 256).cuda()
129+
130+
# feed images into decoder
131+
132+
loss = decoder(images)
133+
loss.backward()
134+
135+
# do the above for many many many many steps
136+
# then it will learn to generate images based on the CLIP image embeddings
137+
```
138+
139+
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP fron the first step
140+
141+
```python
142+
import torch
143+
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
144+
145+
clip = CLIP(
146+
dim_text = 512,
147+
dim_image = 512,
148+
dim_latent = 512,
149+
num_text_tokens = 49408,
150+
text_enc_depth = 6,
151+
text_seq_len = 256,
152+
text_heads = 8,
153+
visual_enc_depth = 6,
154+
visual_image_size = 256,
155+
visual_patch_size = 32,
156+
visual_heads = 8,
157+
).cuda()
158+
159+
# setup prior network, which contains an autoregressive transformer
160+
161+
prior_network = DiffusionPriorNetwork(
162+
dim = 512,
163+
num_timesteps = 100,
164+
depth = 6,
165+
dim_head = 64,
166+
heads = 8
167+
).cuda()
168+
169+
# diffusion prior network, which contains the CLIP and network (with transformer) above
170+
171+
diffusion_prior = DiffusionPrior(
172+
net = prior_network,
173+
clip = clip,
174+
timesteps = 100,
175+
cond_drop_prob = 0.2
176+
).cuda()
177+
178+
# mock data
179+
180+
text = torch.randint(0, 49408, (4, 256)).cuda()
181+
images = torch.randn(4, 3, 256, 256).cuda()
182+
183+
# feed text and images into diffusion prior network
38184

39-
Todo
185+
loss = diffusion_prior(text, images)
186+
loss.backward()
187+
188+
# do the above for many many many steps
189+
# now the diffusion prior can generate image embeddings from the text embeddings
190+
```
191+
192+
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
193+
194+
```python
195+
from dalle2_pytorch import DALLE2
196+
197+
dalle2 = DALLE2(
198+
prior = diffusion_prior,
199+
decoder = decoder
200+
)
201+
202+
# send the text as a string if you want to use the simple tokenizer from DALL-E1
203+
# or you can do it as token ids, if you have your own tokenizer
204+
205+
texts = ['glistening morning dew on a flower petal']
206+
images = dalle2(texts) # (1, 3, 256, 256)
207+
```
208+
209+
That's it!
210+
211+
Let's see the whole script below
212+
213+
```python
214+
import torch
215+
from dalle2_pytorch.dalle2_pytorch import DALLE2
216+
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
217+
218+
import torch
219+
220+
clip = CLIP(
221+
dim_text = 512,
222+
dim_image = 512,
223+
dim_latent = 512,
224+
num_text_tokens = 49408,
225+
text_enc_depth = 6,
226+
text_seq_len = 256,
227+
text_heads = 8,
228+
visual_enc_depth = 6,
229+
visual_image_size = 256,
230+
visual_patch_size = 32,
231+
visual_heads = 8
232+
).cuda()
233+
234+
# mock data
235+
236+
text = torch.randint(0, 49408, (4, 256)).cuda()
237+
images = torch.randn(4, 3, 256, 256).cuda()
238+
239+
# train
240+
241+
loss = clip(
242+
text,
243+
images,
244+
return_loss = True
245+
)
246+
247+
loss.backward()
248+
249+
# do above for many steps ...
250+
251+
# prior networks (with transformer)
252+
253+
prior_network = DiffusionPriorNetwork(
254+
dim = 512,
255+
num_timesteps = 100,
256+
depth = 6,
257+
dim_head = 64,
258+
heads = 8
259+
).cuda()
260+
261+
diffusion_prior = DiffusionPrior(
262+
net = prior_network,
263+
clip = clip,
264+
timesteps = 100,
265+
cond_drop_prob = 0.2
266+
).cuda()
267+
268+
loss = diffusion_prior(text, images)
269+
loss.backward()
270+
271+
# do above for many steps ...
272+
273+
# decoder (with unet)
274+
275+
unet = Unet(
276+
dim = 128,
277+
image_embed_dim = 512,
278+
time_dim = 128,
279+
channels = 3,
280+
dim_mults=(1, 2, 4, 8)
281+
).cuda()
282+
283+
decoder = Decoder(
284+
net = unet,
285+
clip = clip,
286+
timesteps = 100,
287+
cond_drop_prob = 0.2
288+
).cuda()
289+
290+
loss = decoder(images)
291+
loss.backward()
292+
293+
# do above for many steps
294+
295+
dalle2 = DALLE2(
296+
prior = diffusion_prior,
297+
decoder = decoder
298+
)
299+
300+
images = dalle2(['cute puppy chasing after a squirrel'])
301+
302+
# save your image
303+
```
304+
305+
Everything in this readme should run without error
306+
307+
## Training CLI (wip)
308+
309+
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
40310

41311
## Todo
42312

43313
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
44314
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
45-
- [ ] make sure it works end to end to produce an output tensor, taking a single gradient step
315+
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
46316
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
47317
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
48318
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)

dalle2_pytorch/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from dalle2_pytorch.dalle2_pytorch import DALLE2
1+
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
2+
from x_clip import CLIP

0 commit comments

Comments
 (0)