Skip to content

Commit f988207

Browse files
committed
hack around an in-place error; also ensure that for OpenAI CLIP text encoding, only tokens after the eos_id are masked out
1 parent b207321 commit f988207

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

dalle2_pytorch/dalle2_pytorch.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def __init__(
278278
import clip
279279
openai_clip, preprocess = clip.load(name)
280280
super().__init__(openai_clip)
281+
self.eos_id = 49407 # for handling 0 being also '!'
281282

282283
text_attention_final = self.find_layer('ln_final')
283284
self.handle = text_attention_final.register_forward_hook(self._hook)
@@ -316,7 +317,10 @@ def max_text_len(self):
316317
@torch.no_grad()
317318
def embed_text(self, text):
318319
text = text[..., :self.max_text_len]
319-
text_mask = text != 0
320+
321+
is_eos_id = (text == self.eos_id)
322+
text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
323+
text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
320324
assert not self.cleared
321325

322326
text_embed = self.clip.encode_text(text)
@@ -900,7 +904,7 @@ def forward(
900904
null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
901905

902906
text_encodings = torch.where(
903-
rearrange(mask, 'b n -> b n 1'),
907+
rearrange(mask, 'b n -> b n 1').clone(),
904908
text_encodings,
905909
null_text_embeds
906910
)

dalle2_pytorch/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.23.7'
1+
__version__ = '0.23.8'

0 commit comments

Comments
 (0)