Skip to content

Commit 322496f

Browse files
committed
allow for use of larger t5, and extensible to any text encoder model
1 parent 156c496 commit 322496f

File tree

4 files changed

+73
-29
lines changed

4 files changed

+73
-29
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ unet2 = Unet(
3939
# imagen, which contains the unets above (base unet and super resoluting ones)
4040

4141
imagen = Imagen(
42-
unet = (unet1, unet2),
42+
unets = (unet1, unet2),
4343
image_sizes = (64, 256),
4444
timesteps = 100,
4545
cond_drop_prob = 0.5
@@ -73,7 +73,7 @@ images.shape # (3, 3, 256, 256)
7373
- [x] use huggingface transformers for T5-small text embeddings
7474
- [x] add dynamic thresholding
7575
- [x] add dynamic thresholding DALLE2 and video-diffusion repository as well
76-
- [ ] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
76+
- [x] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
7777
- [ ] separate unet into base unet and SR3 unet
7878
- [ ] build whatever efficient unet they came up with
7979
- [ ] add the noise level conditioning with the pseudocode in appendix, and figure out what is this sweep they do at inference time

imagen_pytorch/imagen_pytorch.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from resize_right import resize
2424

25-
from imagen_pytorch.t5 import t5_encode_text, T5_SMALL_EMBED_DIM
25+
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim
2626

2727
# constants
2828

@@ -233,7 +233,7 @@ def __init__(self, *, beta_schedule, timesteps, loss_type):
233233

234234
# register buffer helper function to cast double back to float
235235

236-
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
236+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32), persistent = False)
237237

238238
register_buffer('betas', betas)
239239
register_buffer('alphas_cumprod', alphas_cumprod)
@@ -691,7 +691,7 @@ def __init__(
691691
dim,
692692
*,
693693
image_embed_dim = None,
694-
text_embed_dim = T5_SMALL_EMBED_DIM,
694+
text_embed_dim = 512,
695695
cond_dim = None,
696696
num_image_tokens = 4,
697697
num_time_tokens = 2,
@@ -842,18 +842,21 @@ def cast_model_parameters(
842842
self,
843843
*,
844844
lowres_cond,
845+
text_embed_dim,
845846
channels,
846847
channels_out,
847848
cond_on_text
848849
):
849850
if lowres_cond == self.lowres_cond and \
850851
channels == self.channels and \
851852
cond_on_text == self.cond_on_text and \
853+
text_embed_dim == self._locals['text_embed_dim'] and \
852854
channels_out == self.channels_out:
853855
return self
854856

855857
updated_kwargs = dict(
856858
lowres_cond = lowres_cond,
859+
text_embed_dim = text_embed_dim,
857860
channels = channels,
858861
channels_out = channels_out,
859862
cond_on_text = cond_on_text
@@ -1020,8 +1023,9 @@ def forward(
10201023
class Imagen(BaseGaussianDiffusion):
10211024
def __init__(
10221025
self,
1023-
unet,
1026+
unets,
10241027
*,
1028+
text_encoder_name = 't5-small',
10251029
image_size = None,
10261030
channels = 3,
10271031
timesteps = 1000,
@@ -1068,14 +1072,18 @@ def __init__(
10681072
# automatically take care of ensuring that first unet is unconditional
10691073
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
10701074

1071-
unets = cast_tuple(unet)
1075+
unets = cast_tuple(unets)
10721076

10731077
# whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
10741078

10751079
learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
10761080
self.learned_variance = learned_variance
10771081
self.vb_loss_weight = vb_loss_weight
10781082

1083+
# get text encoder
1084+
1085+
text_embed_dim = get_encoded_dim(text_encoder_name)
1086+
10791087
# construct unets
10801088

10811089
self.unets = nn.ModuleList([])
@@ -1089,6 +1097,7 @@ def __init__(
10891097
one_unet = one_unet.cast_model_parameters(
10901098
lowres_cond = not is_first,
10911099
cond_on_text = one_unet.cond_on_text and not unconditional,
1100+
text_embed_dim = text_embed_dim,
10921101
channels = self.channels,
10931102
channels_out = unet_channels_out
10941103
)

imagen_pytorch/t5.py

+56-21
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,73 @@
44
def exists(val):
55
return val is not None
66

7+
# config
8+
9+
MAX_LENGTH = 256
10+
11+
T5_CONFIGS = {
12+
't5-small': {
13+
'dim': 512
14+
},
15+
't5-large': {
16+
'dim': 1024
17+
}
18+
}
19+
720
# singleton globals
821

9-
MODEL = None
10-
TOKENIZER = None
11-
T5_SMALL_EMBED_DIM = 512
22+
def get_tokenizer(name):
23+
assert name in T5_CONFIGS
24+
tokenizer = T5Tokenizer.from_pretrained("t5-small")
25+
return tokenizer
26+
27+
def get_model(name):
28+
assert name in T5_CONFIGS
29+
model = T5ForConditionalGeneration.from_pretrained("t5-small")
30+
return model
31+
32+
def get_model_and_tokenizer(name):
33+
global T5_CONFIGS
34+
assert name in T5_CONFIGS, f'{name} model is not found in the configuration'
35+
config = T5_CONFIGS[name]
1236

13-
def get_tokenizer():
14-
global TOKENIZER
15-
if not exists(TOKENIZER):
16-
TOKENIZER = T5Tokenizer.from_pretrained("t5-small")
17-
return TOKENIZER
37+
if not 'model' in config:
38+
model = get_model(name)
39+
config['model'] = model
1840

19-
def get_t5():
20-
global MODEL
21-
if not exists(MODEL):
22-
MODEL = T5ForConditionalGeneration.from_pretrained("t5-small")
23-
if torch.cuda.is_available():
24-
MODEL = MODEL.cuda()
41+
if not 'tokenizer' in config:
42+
tokenizer = get_tokenizer(name)
43+
config['tokenizer'] = tokenizer
2544

26-
return MODEL
45+
return config['model'], config['tokenizer']
46+
47+
def get_encoded_dim(name):
48+
assert name in T5_CONFIGS, f'{name} model is not found in configuration'
49+
return T5_CONFIGS[name]['dim']
2750

2851
# encoding text
2952

30-
def t5_encode_text(texts):
31-
t5 = get_t5()
32-
tokenizer = get_tokenizer()
53+
def t5_encode_text(texts, name = 't5-small'):
54+
t5, tokenizer = get_model_and_tokenizer(name)
55+
56+
if torch.cuda.is_available():
57+
t5 = t5.cuda()
58+
59+
device = next(t5.parameters()).device
60+
61+
encoded = tokenizer.batch_encode_plus(
62+
texts,
63+
return_tensors = "pt",
64+
padding = 'longest',
65+
max_length = MAX_LENGTH,
66+
truncation = True
67+
)
3368

34-
input_ids = tokenizer.batch_encode_plus(texts, return_tensors = "pt", padding = True, truncation = True).input_ids
35-
input_ids = input_ids.to(next(t5.parameters()).device)
69+
input_ids = encoded.input_ids.to(device)
70+
attn_mask = encoded.attention_mask.to(device)
3671

3772
t5.eval()
3873
with torch.no_grad():
39-
output = t5(input_ids = input_ids, decoder_input_ids = input_ids[:, :1]) # too lazy to figure out how to make it work without decoder inputs
74+
output = t5(input_ids = input_ids, attention_mask = attn_mask, decoder_input_ids = input_ids[:, :1]) # too lazy to figure out how to make it work without decoder inputs
4075

4176
return output.encoder_last_hidden_state

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'imagen-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.5',
6+
version = '0.0.6',
77
license='MIT',
88
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)