 from dalle_pytorch import distributed_utils
 from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 from dalle_pytorch.transformer import Transformer, DivideMax
-from dalle_pytorch.attention import stable_softmax

 # helpers

@@ -48,6 +47,16 @@ def inner(model, *args, **kwargs):

 # sampling helpers

+def log(t, eps = 1e-20):
+    return torch.log(t + eps)
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+
 def top_k(logits, thres = 0.5):
     num_logits = logits.shape[-1]
     k = max(int((1 - thres) * num_logits), 1)
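The helpers added above implement the Gumbel-max trick: adding Gumbel noise to temperature-scaled logits and taking the argmax draws from the same categorical distribution as softmax followed by torch.multinomial, which is what the removed stable_softmax path did. A minimal sketch of that equivalence (illustrative only, not part of this commit; helper definitions copied from the hunk above):

import torch
import torch.nn.functional as F

# same helpers as added in the commit above
def log(t, eps = 1e-20):
    return torch.log(t + eps)

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

logits = torch.randn(8) * 3                      # toy logits over 8 tokens
n = 50_000

# empirical token frequencies from repeated Gumbel-max draws
draws = torch.stack([gumbel_sample(logits) for _ in range(n)])
empirical = torch.bincount(draws, minlength = 8).float() / n

# these should closely match the softmax probabilities that
# stable_softmax + torch.multinomial sampled from previously
print(empirical)
print(F.softmax(logits, dim = -1))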
@@ -442,10 +451,9 @@ def generate_texts(
             logits = logits[:, -1, :]

             filtered_logits = top_k(logits, thres = filter_thres)
-            probs = stable_softmax(filtered_logits / temperature, dim = -1)
-            sample = torch.multinomial(probs, 1)
-
-            text_tokens = torch.cat((text_tokens, sample), dim=-1)
+            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
+
+            text_tokens = torch.cat((text_tokens, sample[:, None]), dim=-1)

         padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
         texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
@@ -486,14 +494,14 @@ def generate_images(

             text, image = out[:, :text_seq_len], out[:, text_seq_len:]

-            logits = self(text, image, mask = mask)[:, -1, :]
+            logits = self(text, image, mask = mask)
+            logits = logits[:, -1, :]

             filtered_logits = top_k(logits, thres = filter_thres)
-            probs = stable_softmax(filtered_logits / temperature, dim = -1)
-            sample = torch.multinomial(probs, 1)
+            sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

             sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
-            out = torch.cat((out, sample), dim=-1)
+            out = torch.cat((out, sample[:, None]), dim=-1)

             if out.shape[1] <= text_seq_len:
                 mask = F.pad(mask, (0, 1), value = True)
|
|
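One consequence of the switch, visible in the two hunks above: torch.multinomial(probs, 1) returned a column of shape (batch, 1), while gumbel_sample returns an argmax of shape (batch,), so the sampled token is unsqueezed with sample[:, None] before being concatenated onto the running sequence. A toy shape check (hypothetical batch of 2 and sequence length 4, not from the commit):

import torch

out = torch.zeros(2, 4, dtype = torch.long)      # running token sequence, shape (2, 4)
sample = torch.tensor([7, 3])                    # gumbel_sample output, shape (2,)

out = torch.cat((out, sample[:, None]), dim = -1)
print(out.shape)                                 # torch.Size([2, 5])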