Commit 3ee3c56

add learned padding tokens, same strategy as dalle1, for diffusion prior, and get rid of masking in causal transformer
1 parent cd26c6b commit 3ee3c56

3 files changed: +31 -15 lines changed


dalle2_pytorch/dalle2_pytorch.py

+29 -13

@@ -782,17 +782,13 @@ def __init__(
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
         self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
-    def forward(
-        self,
-        x,
-        mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
-    ):
+    def forward(self, x):
         n, device = x.shape[1], x.device
 
         attn_bias = self.rel_pos_bias(n, n + 1, device = device)
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask, attn_bias = attn_bias) + x
+            x = attn(x, attn_bias = attn_bias) + x
             x = ff(x) + x
 
         out = self.norm(x)
@@ -806,7 +802,7 @@ def __init__(
         num_time_embeds = 1,
         num_image_embeds = 1,
         num_text_embeds = 1,
-        attend_all_text_encodings = True,
+        max_text_len = 256,
         **kwargs
     ):
         super().__init__()
@@ -832,7 +828,10 @@ def __init__(
         self.learned_query = nn.Parameter(torch.randn(dim))
         self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
 
-        self.attend_all_text_encodings = attend_all_text_encodings
+        # dalle1 learned padding strategy
+
+        self.max_text_len = max_text_len
+        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
     def forward_with_cond_scale(
         self,
@@ -872,11 +871,28 @@ def forward(
 
         if not exists(text_encodings):
             text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
+
+        mask = torch.any(text_encodings != 0., dim = -1)
 
-        if self.attend_all_text_encodings:
-            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
-        else:
-            mask = torch.any(text_encodings != 0., dim = -1)
+        # replace any padding in the text encodings with learned padding tokens unique across position
+
+        text_encodings = text_encodings[:, :self.max_text_len]
+        mask = mask[:, :self.max_text_len]
+
+        text_len = text_encodings.shape[-2]
+        remainder = self.max_text_len - text_len
+
+        if remainder > 0:
+            text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+            mask = F.pad(mask, (0, remainder), value = False)
+
+        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+
+        text_encodings = torch.where(
+            rearrange(mask, 'b n -> b n 1'),
+            text_encodings,
+            null_text_embeds
+        )
 
         # classifier free guidance
 
@@ -910,7 +926,7 @@ def forward(
 
         # attend
 
-        tokens = self.causal_transformer(tokens, mask = mask)
+        tokens = self.causal_transformer(tokens)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
 
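
To illustrate the DALL-E 1 style padding strategy described in the commit message, here is a minimal, self-contained sketch in the same spirit as the new forward above. It is not the library's API: the tensor names and toy shapes are made up, and it assumes, as the diff does, that padded positions in the text encodings are exactly zero.

import torch
import torch.nn.functional as F
from einops import rearrange

batch, dim, max_text_len = 2, 512, 256

# learned padding embeddings, one per position (same shape as in the diff)
null_text_embed = torch.nn.Parameter(torch.randn(1, max_text_len, dim))

# toy text encodings: 10 real tokens per sample, the rest zero padding
text_encodings = torch.zeros(batch, 77, dim)
text_encodings[:, :10] = torch.randn(batch, 10, dim)

# padded positions are assumed to be exactly zero across the feature dimension
mask = torch.any(text_encodings != 0., dim = -1)

# truncate, then pad out to max_text_len
text_encodings = text_encodings[:, :max_text_len]
mask = mask[:, :max_text_len]
remainder = max_text_len - text_encodings.shape[-2]

if remainder > 0:
    text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
    mask = F.pad(mask, (0, remainder), value = False)

# substitute the learned, position-specific padding embedding wherever the mask is False
text_encodings = torch.where(
    rearrange(mask, 'b n -> b n 1'),
    text_encodings,
    null_text_embed.to(text_encodings.dtype)
)

print(text_encodings.shape)  # torch.Size([2, 256, 512])

Because every one of the max_text_len positions now carries either a real encoding or its own learned padding embedding, the transformer can attend to all positions unconditionally, which is why the mask argument could be dropped from CausalTransformer.forward.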

dalle2_pytorch/train_configs.py

+1 -1

@@ -129,11 +129,11 @@ def create(self):
 class DiffusionPriorNetworkConfig(BaseModel):
     dim: int
     depth: int
+    max_text_len: int = None
     num_timesteps: int = None
     num_time_embeds: int = 1
     num_image_embeds: int = 1
     num_text_embeds: int = 1
-    attend_all_text_encodings: bool = True
     dim_head: int = 64
     heads: int = 8
     ff_mult: int = 4
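
For reference, a hypothetical config fragment showing where the new field would be set; only fields visible in the hunk above are included, and the values are placeholders:

# hypothetical prior network config fragment; max_text_len replaces the
# removed attend_all_text_encodings flag, all values are placeholders
prior_network_config = {
    "dim": 512,
    "depth": 6,
    "max_text_len": 256,  # new field added in this commit
    "dim_head": 64,
    "heads": 8,
    "ff_mult": 4,
}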

dalle2_pytorch/version.py

+1 -1

@@ -1 +1 @@
-__version__ = '0.22.3'
+__version__ = '0.23.1'
