Skip to content

Commit 7cf1637

Browse files
committed
bring in the simple tokenizer released by openai, but also plan on leaving room for custom tokenizer with yttm
1 parent 4ff6d02 commit 7cf1637

File tree

5 files changed

+262394
-11
lines changed

5 files changed

+262394
-11
lines changed

MANIFEST.in

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
recursive-include dalle2_pytorch *.txt

dalle2_pytorch/dalle2_pytorch.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ def exists(val):
1515
def default(val, d):
1616
return val if exists(val) else d
1717

18+
def eval_decorator(fn):
19+
def inner(model, *args, **kwargs):
20+
was_training = model.training
21+
model.eval()
22+
out = fn(model, *args, **kwargs)
23+
model.train(was_training)
24+
return out
25+
return inner
26+
1827
# for controlling freezing of CLIP
1928

2029
def set_module_requires_grad_(module, requires_grad):
@@ -30,24 +39,61 @@ def unfreeze_all_layers_(module):
3039
# diffusion prior
3140

3241
class DiffusionPrior(nn.Module):
33-
def __init__(self):
42+
def __init__(
43+
self,
44+
*,
45+
clip
46+
):
3447
super().__init__()
35-
def forward(self, x):
36-
return x
48+
assert isinstance(clip, CLIP)
49+
50+
def forward(
51+
self,
52+
*,
53+
text,
54+
image
55+
):
56+
return text
3757

3858
# decoder
3959

4060
class Decoder(nn.Module):
41-
def __init__(self):
61+
def __init__(
62+
self,
63+
*,
64+
clip,
65+
prior
66+
):
4267
super().__init__()
43-
def forward(self, x):
44-
return x
68+
assert isinstance(clip, CLIP)
69+
assert isinstance(prior, DiffusionPrior)
70+
71+
def forward(
72+
self,
73+
*,
74+
image
75+
):
76+
return image
4577

4678
# main class
4779

4880
class DALLE2(nn.Module):
49-
def __init__(self):
81+
def __init__(
82+
self,
83+
*,
84+
clip,
85+
prior,
86+
decoder
87+
):
5088
super().__init__()
51-
52-
def forward(self, x):
53-
return x
89+
assert isinstance(clip), CLIP
90+
assert isinstance(prior), DiffusionPrior
91+
assert isinstance(decoder), Decoder
92+
93+
@torch.no_grad()
94+
def forward(
95+
self,
96+
*,
97+
text
98+
):
99+
return text

0 commit comments

Comments
 (0)