
Commit 7e93b9d

make sure classifier free guidance condition scaling is exposed on DALLE2 forward function
1 parent 4c827ba

3 files changed: +26 −20 lines changed

README.md (+4 −1)

````diff
@@ -296,7 +296,10 @@ dalle2 = DALLE2(
     decoder = decoder
 )
 
-images = dalle2(['cute puppy chasing after a squirrel'])
+images = dalle2(
+    ['cute puppy chasing after a squirrel'],
+    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
+)
 
 # save your image
 ```
````
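As context for the new `cond_scale` argument, here is a minimal numeric sketch of the classifier free guidance mix this commit exposes (the tensors are illustrative stand-ins, not outputs from the repository's networks):

```python
import torch

# illustrative stand-ins for the two network passes
logits      = torch.randn(1, 4)   # conditioned prediction: forward(x, ...)
null_logits = torch.randn(1, 4)   # unconditioned prediction: forward(x, cond_drop_prob = 1., ...)

cond_scale = 2.
guided = null_logits + (logits - null_logits) * cond_scale  # the mix used in the diff below

# cond_scale = 1. recovers the conditioned prediction exactly, which is why
# forward_with_cond_scale can skip the second (unconditioned) pass in that case
assert torch.allclose(null_logits + (logits - null_logits) * 1., logits)
```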

dalle2_pytorch/dalle2_pytorch.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,16 @@ def __init__(
246246
def forward_with_cond_scale(
247247
self,
248248
x,
249-
*,
249+
*args,
250250
cond_scale = 1.,
251251
**kwargs
252252
):
253+
logits = self.forward(x, *args, **kwargs)
254+
253255
if cond_scale == 1:
254-
return self.forward(x, **kwargs)
256+
return logits
255257

256-
logits = self.forward(x, **kwargs)
257-
null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
258+
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
258259
return null_logits + (logits - null_logits) * cond_scale
259260

260261
def forward(
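The switch from the keyword-only marker `*` to `*args` is what makes the later hunks work: `p_mean_variance` now calls `forward_with_cond_scale(x, t, ...)` with the timestep passed positionally. A minimal sketch with a hypothetical stub network (only the guidance arithmetic mirrors the diff):

```python
import torch
import torch.nn as nn

class StubNet(nn.Module):
    # hypothetical network whose forward takes the timestep positionally, like the diff's self.net
    def forward(self, x, t, cond_drop_prob = 0.):
        return torch.zeros_like(x) if cond_drop_prob == 1. else x + t

    def forward_with_cond_scale(self, x, *args, cond_scale = 1., **kwargs):
        logits = self.forward(x, *args, **kwargs)
        if cond_scale == 1:
            return logits
        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

net = StubNet()
x, t = torch.zeros(2, 3), torch.ones(2, 3)
# under the old keyword-only signature, this positional t would raise a TypeError
out = net.forward_with_cond_scale(x, t, cond_scale = 2.)  # tensor of 2.0s
```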
```diff
@@ -635,15 +636,16 @@ def __init__(
     def forward_with_cond_scale(
         self,
         x,
-        *,
+        *args,
         cond_scale = 1.,
         **kwargs
     ):
+        logits = self.forward(x, *args, **kwargs)
+
         if cond_scale == 1:
-            return self.forward(x, **kwargs)
+            return logits
 
-        logits = self.forward(x, **kwargs)
-        null_logits = self.forward(x, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale
 
     def forward(
```
```diff
@@ -774,8 +776,8 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
-        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
+    def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
+        x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
 
         if clip_denoised:
             x_recon.clamp_(-1., 1.)
```
```diff
@@ -784,31 +786,31 @@ def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, shape, image_embed):
+    def p_sample_loop(self, shape, image_embed, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
         img = torch.randn(shape, device=device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
+            img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
         return img
 
     @torch.no_grad()
-    def sample(self, image_embed):
+    def sample(self, image_embed, cond_scale = 1.):
         batch_size = image_embed.shape[0]
         image_size = self.image_size
         channels = self.channels
-        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
+        return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
 
     def q_sample(self, x_start, t, noise=None):
         noise = default(noise, lambda: torch.randn_like(x_start))
```
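Taken together, these hunks thread `cond_scale` through the whole ancestral sampling chain, `sample` → `p_sample_loop` → `p_sample` → `p_mean_variance`. A stripped-down trace of that plumbing (hypothetical stubs; the real methods also compute means, variances, and noise):

```python
num_timesteps = 3  # illustrative; the real value comes from the diffusion schedule

def p_mean_variance(x, t, image_embed, cond_scale = 1.):
    print(f"t = {t}: denoising with cond_scale = {cond_scale}")
    return x

def p_sample(x, t, image_embed, cond_scale = 1.):
    return p_mean_variance(x, t, image_embed, cond_scale = cond_scale)

def p_sample_loop(shape, image_embed, cond_scale = 1.):
    img = None  # stands in for torch.randn(shape, device = device)
    for t in reversed(range(num_timesteps)):
        img = p_sample(img, t, image_embed, cond_scale = cond_scale)
    return img

def sample(image_embed, cond_scale = 1.):
    return p_sample_loop((1, 3, 64, 64), image_embed, cond_scale = cond_scale)

sample(image_embed = None, cond_scale = 2.)  # every timestep sees cond_scale = 2.0
```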
```diff
@@ -869,7 +871,8 @@ def __init__(
     @torch.no_grad()
     def forward(
         self,
-        text
+        text,
+        cond_scale = 1.
     ):
         device = next(self.parameters()).device
 
```

```diff
@@ -878,5 +881,5 @@ def forward(
         text = tokenizer.tokenize(text).to(device)
 
         image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
-        images = self.decoder.sample(image_embed)
+        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
         return images
```
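With the two hunks above, the user-facing `DALLE2.forward` accepts `cond_scale` and simply forwards it to `Decoder.sample`. A stubbed end-to-end sketch so the flow is runnable (`StubPrior` and `StubDecoder` are hypothetical stand-ins, not the repository's classes):

```python
import torch

class StubPrior:
    def sample(self, text, num_samples_per_batch = 2):
        return torch.randn(len(text), 512)               # fake image embeddings

class StubDecoder:
    def sample(self, image_embed, cond_scale = 1.):
        assert cond_scale == 2.                          # shows the argument arrived
        return torch.rand(image_embed.shape[0], 3, 64, 64)

prior, decoder = StubPrior(), StubDecoder()
image_embed = prior.sample(['cute puppy chasing after a squirrel'])
images = decoder.sample(image_embed, cond_scale = 2.)    # same path DALLE2.forward takes
```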

setup.py (+1 −1)

```diff
@@ -10,7 +10,7 @@
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
```
