Skip to content

Commit 032bbe5

Browse files
committed
make sure number of resnet blocks is customizable, address difference between t5 encoder and google/t5-v1.1
1 parent d2f6c72 commit 032bbe5

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

imagen_pytorch/imagen_pytorch.py

+23-12
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def __init__(
601601
*,
602602
image_embed_dim = 1024,
603603
text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
604+
num_resnet_blocks = 1,
604605
cond_dim = None,
605606
num_image_tokens = 4,
606607
num_time_tokens = 2,
@@ -706,6 +707,7 @@ def __init__(
706707

707708
# resnet block klass
708709

710+
num_resnet_blocks = cast_tuple(num_resnet_blocks, len(in_out))
709711
resnet_groups = cast_tuple(resnet_groups, len(in_out))
710712

711713
assert len(resnet_groups) == len(in_out)
@@ -722,15 +724,15 @@ def __init__(
722724
self.ups = nn.ModuleList([])
723725
num_resolutions = len(in_out)
724726

725-
for ind, ((dim_in, dim_out), groups) in enumerate(zip(in_out, resnet_groups)):
727+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups) in enumerate(zip(in_out, num_resnet_blocks, resnet_groups)):
726728
is_first = ind == 0
727729
is_last = ind >= (num_resolutions - 1)
728730
layer_cond_dim = cond_dim if not is_first else None
729731

730732
self.downs.append(nn.ModuleList([
731733
ResnetBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, groups = groups),
732734
Residual(LinearAttention(dim_out, **attn_kwargs)) if sparse_attn else nn.Identity(),
733-
ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
735+
nn.ModuleList([ResnetBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
734736
downsample_klass(dim_out) if not is_last else nn.Identity()
735737
]))
736738

@@ -740,14 +742,14 @@ def __init__(
740742
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
741743
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
742744

743-
for ind, ((dim_in, dim_out), groups) in enumerate(zip(reversed(in_out[1:]), reversed(resnet_groups))):
745+
for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups) in enumerate(zip(reversed(in_out[1:]), reversed(num_resnet_blocks), reversed(resnet_groups))):
744746
is_last = ind >= (num_resolutions - 2)
745747
layer_cond_dim = cond_dim if not is_last else None
746748

747749
self.ups.append(nn.ModuleList([
748750
ResnetBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
749751
Residual(LinearAttention(dim_in, **attn_kwargs)) if sparse_attn else nn.Identity(),
750-
ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
752+
nn.ModuleList([ResnetBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
751753
Upsample(dim_in)
752754
]))
753755

@@ -891,10 +893,13 @@ def forward(
891893

892894
hiddens = []
893895

894-
for block1, sparse_attn, block2, downsample in self.downs:
895-
x = block1(x, c, t)
896+
for init_block, sparse_attn, resnet_blocks, downsample in self.downs:
897+
x = init_block(x, c, t)
896898
x = sparse_attn(x)
897-
x = block2(x, c, t)
899+
900+
for resnet_block in resnet_blocks:
901+
x = resnet_block(x, c, t)
902+
898903
hiddens.append(x)
899904
x = downsample(x)
900905

@@ -905,11 +910,14 @@ def forward(
905910

906911
x = self.mid_block2(x, mid_c, t)
907912

908-
for block1, sparse_attn, block2, upsample in self.ups:
913+
for init_block, sparse_attn, resnet_blocks, upsample in self.ups:
909914
x = torch.cat((x, hiddens.pop()), dim=1)
910-
x = block1(x, c, t)
915+
x = init_block(x, c, t)
911916
x = sparse_attn(x)
912-
x = block2(x, c, t)
917+
918+
for resnet_block in resnet_blocks:
919+
x = resnet_block(x, c, t)
920+
913921
x = upsample(x)
914922

915923
return self.final_conv(x)
@@ -962,7 +970,7 @@ def __init__(
962970
# get text encoder
963971

964972
self.text_encoder_name = text_encoder_name
965-
text_embed_dim = get_encoded_dim(text_encoder_name)
973+
self.text_embed_dim = get_encoded_dim(text_encoder_name)
966974

967975
# construct unets
968976

@@ -977,7 +985,7 @@ def __init__(
977985
one_unet = one_unet.cast_model_parameters(
978986
lowres_cond = not is_first,
979987
cond_on_text = self.condition_on_text,
980-
text_embed_dim = text_embed_dim if self.condition_on_text else None,
988+
text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
981989
channels = self.channels,
982990
channels_out = unet_channels_out
983991
)
@@ -1211,6 +1219,7 @@ def sample(
12111219

12121220
assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
12131221
assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
1222+
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
12141223

12151224
img = None
12161225
is_cuda = next(self.parameters()).is_cuda
@@ -1282,6 +1291,8 @@ def forward(
12821291
assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
12831292
assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
12841293

1294+
assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
1295+
12851296
lowres_cond_img = lowres_aug_times = None
12861297
if exists(prev_image_size):
12871298
lowres_cond_img = resize_image_to(image, prev_image_size)

imagen_pytorch/t5.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,16 @@ def t5_encode_text(texts, name = 't5-small'):
104104
attn_mask = encoded.attention_mask.to(device)
105105

106106
t5.eval()
107+
108+
config = T5_CONFIGS[name]
109+
src = config['src']
110+
107111
with torch.no_grad():
108-
output = t5(input_ids = input_ids, attention_mask = attn_mask) # too lazy to figure out how to make it work without decoder inputs
109-
encoded_text = output.last_hidden_state.detach()
112+
if src == 't5':
113+
output = t5(input_ids = input_ids, attention_mask = attn_mask)
114+
encoded_text = output.last_hidden_state.detach()
115+
elif src == 'auto':
116+
output = t5(input_ids = input_ids, attention_mask = attn_mask, decoder_input_ids = input_ids[:, :1])
117+
encoded_text = output.encoder_last_hidden_state.detach()
110118

111119
return encoded_text, attn_mask.bool()

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'imagen-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.21',
6+
version = '0.0.22',
77
license='MIT',
88
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
99
author = 'Phil Wang',

0 commit comments

Comments (0)