Commit 07f6f3f

remove softmax and torch multinomial when sampling
1 parent: 2783801

2 files changed: +18 −10 lines

dalle_pytorch/dalle_pytorch.py (+17 −9)
@@ -10,7 +10,6 @@
 from dalle_pytorch import distributed_utils
 from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 from dalle_pytorch.transformer import Transformer, DivideMax
-from dalle_pytorch.attention import stable_softmax

 # helpers

@@ -48,6 +47,16 @@ def inner(model, *args, **kwargs):

 # sampling helpers

+def log(t, eps = 1e-20):
+    return torch.log(t + eps)
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+
 def top_k(logits, thres = 0.5):
     num_logits = logits.shape[-1]
     k = max(int((1 - thres) * num_logits), 1)
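
The new helpers implement the Gumbel-max trick: add i.i.d. Gumbel(0, 1) noise to the temperature-scaled logits and take the argmax. The winning index follows exactly the same categorical distribution as softmax followed by torch.multinomial, but no probabilities are ever materialized, which is why the stable_softmax import above can be dropped. Temperature is folded in by dividing the logits before adding noise, matching the old filtered_logits / temperature behavior. A minimal self-contained sketch checking the equivalence empirically (the logits and sample count below are arbitrary, not part of the commit):

import torch
import torch.nn.functional as F

def log(t, eps = 1e-20):
    return torch.log(t + eps)

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

logits = torch.tensor([2.0, 1.0, 0.5])  # arbitrary example logits
n = 100_000

# draw n samples with each method and compare empirical distributions
gumbel_counts = torch.bincount(gumbel_sample(logits.expand(n, -1)), minlength = 3)
multinomial_counts = torch.bincount(
    torch.multinomial(F.softmax(logits, dim = -1), n, replacement = True),
    minlength = 3)

print(gumbel_counts.float() / n)       # ~softmax(logits), within sampling noise
print(multinomial_counts.float() / n)  # same distribution, within sampling noise
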
@@ -442,10 +451,9 @@ def generate_texts(
             logits = logits[:, -1, :]

             filtered_logits = top_k(logits, thres = filter_thres)
-            probs = stable_softmax(filtered_logits / temperature, dim = -1)
-            sample = torch.multinomial(probs, 1)
-
-            text_tokens = torch.cat((text_tokens, sample), dim=-1)
+            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
+
+            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

         padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
         texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
@@ -486,14 +494,14 @@ def generate_images(

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

-           logits = self(text, image, mask = mask)[:, -1, :]
+           logits = self(text, image, mask = mask)
+           logits = logits[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
-           probs = stable_softmax(filtered_logits / temperature, dim = -1)
-           sample = torch.multinomial(probs, 1)
+           sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
-           out = torch.cat((out, sample), dim=-1)
+           out = torch.cat((out, sample[:, None]), dim=-1)

            if out.shape[1] <= text_seq_len:
                mask = F.pad(mask, (0, 1), value = True)
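
One behavioral detail of the swap: torch.multinomial(probs, 1) returned a (batch, 1) column, while gumbel_sample's argmax returns a flat (batch,) vector, so both call sites now re-add the trailing dimension with sample[:, None] before concatenating onto the running sequence. A minimal shape sketch (sizes are arbitrary):

import torch

batch, vocab = 4, 10
logits = torch.randn(batch, vocab)

sample = logits.argmax(dim = -1)  # shape (batch,), as gumbel_sample returns
column = sample[:, None]          # shape (batch, 1), matching torch.multinomial(probs, 1)

out = torch.zeros(batch, 7, dtype = torch.long)  # stand-in for the running token sequence
out = torch.cat((out, column), dim = -1)
print(out.shape)                  # torch.Size([4, 8])
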

setup.py (+1 −1)
@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.1.7',
+  version = '1.1.8',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
