
Commit 35e9776

v0.2.0, MinDalleTorch -> MinDalle, breaking change

1 parent 2080e59 · commit 35e9776

10 files changed: +43 -45 lines
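Because the rename is a breaking change, downstream code needs a one-line update. A minimal before/after sketch, based only on the import and constructor changes in the diffs below:

```python
# Before (v0.1.x): the class carried a Torch suffix.
# from min_dalle import MinDalleTorch
# model = MinDalleTorch(is_mega=True, is_reusable=True)

# After (v0.2.0): same constructor arguments, new name.
from min_dalle import MinDalle

model = MinDalle(is_mega=True, is_reusable=True)
```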

README.md (+3 -3)

````diff
@@ -22,12 +22,12 @@ $ pip install min-dalle
 
 ### Python
 
-To load a model once and generate multiple times, first initialize `MinDalleTorch`.
+To load a model once and generate multiple times, first initialize `MinDalle`.
 
 ```python
-from min_dalle import MinDalleTorch
+from min_dalle import MinDalle
 
-model = MinDalleTorch(
+model = MinDalle(
     is_mega=True,
     is_reusable=True,
     models_root='./pretrained'
````
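The hunk ends at the constructor; generating from the loaded model would then look roughly like the sketch below. Note that `generate_image` and its parameters are assumptions here, not part of this hunk:

```python
# Sketch only: generate_image and its signature are assumed, not shown
# in the README hunk above.
image = model.generate_image('alien life', seed=-1)  # a PIL.Image
image.save('generated.png')
```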

image_from_text.py (+4 -3)

```diff
@@ -2,14 +2,15 @@
 import os
 from PIL import Image
 
-from min_dalle import MinDalleTorch
+from min_dalle import MinDalle
+
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--mega', action='store_true')
 parser.add_argument('--no-mega', dest='mega', action='store_false')
 parser.set_defaults(mega=False)
 parser.add_argument('--text', type=str, default='alien life')
-parser.add_argument('--seed', type=int, default=7)
+parser.add_argument('--seed', type=int, default=-1)
 parser.add_argument('--image_path', type=str, default='generated')
 parser.add_argument('--token_count', type=int, default=256) # for debugging
 
@@ -39,7 +40,7 @@ def generate_image(
     image_path: str,
     token_count: int
 ):
-    model = MinDalleTorch(
+    model = MinDalle(
         is_mega=is_mega,
         models_root='pretrained',
         is_reusable=False,
```
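Taken together, the flags above make the script usable from the shell; for example, overriding the new default seed of `-1` (presumably a request for a random seed) with a fixed one:

```
$ python image_from_text.py --mega --text='alien life' --seed=7
```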

min_dalle.ipynb (+2 -2)

```diff
@@ -77,9 +77,9 @@
 }
 ],
 "source": [
-"from min_dalle import MinDalleTorch\n",
+"from min_dalle import MinDalle\n",
 "\n",
-"model = MinDalleTorch(is_mega=True, is_reusable=True)"
+"model = MinDalle(is_mega=True, is_reusable=True)"
 ]
 },
 {
```

min_dalle/__init__.py (+1 -1)

```diff
@@ -1 +1 @@
-from .min_dalle_torch import MinDalleTorch
+from .min_dalle import MinDalle
```

min_dalle/min_dalle_torch.py renamed to min_dalle/min_dalle.py (+6 -10)

```diff
@@ -1,6 +1,5 @@
 import os
 from PIL import Image
-from typing import Dict
 import numpy
 from torch import LongTensor
 import torch
@@ -10,16 +9,13 @@
 torch.set_grad_enabled(False)
 torch.set_num_threads(os.cpu_count())
 
+from .text_tokenizer import TextTokenizer
+from .models import DalleBartEncoder, DalleBartDecoder, VQGanDetokenizer
+
 MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
 
-from .text_tokenizer import TextTokenizer
-from .models import (
-    DalleBartEncoderTorch,
-    DalleBartDecoderTorch,
-    VQGanDetokenizer
-)
 
-class MinDalleTorch:
+class MinDalle:
     def __init__(
         self,
         is_mega: bool,
@@ -104,7 +100,7 @@ def init_encoder(self):
         is_downloaded = os.path.exists(self.encoder_params_path)
         if not is_downloaded: self.download_encoder()
         print("initializing DalleBartEncoderTorch")
-        self.encoder = DalleBartEncoderTorch(
+        self.encoder = DalleBartEncoder(
             attention_head_count = self.attention_head_count,
             embed_count = self.embed_count,
             glu_embed_count = self.glu_embed_count,
@@ -122,7 +118,7 @@ def init_decoder(self):
         is_downloaded = os.path.exists(self.decoder_params_path)
         if not is_downloaded: self.download_decoder()
         print("initializing DalleBartDecoderTorch")
-        self.decoder = DalleBartDecoderTorch(
+        self.decoder = DalleBartDecoder(
             sample_token_count = self.sample_token_count,
             image_token_count = self.image_token_count,
             image_vocab_count = self.image_vocab_count,
```
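Both `init_encoder` and `init_decoder` follow the same download-on-first-use pattern: check whether the checkpoint exists on disk, fetch it from `MIN_DALLE_REPO` if not, then construct the module. A minimal sketch of that pattern; the helper name is hypothetical:

```python
import os

def ensure_downloaded(params_path: str, download) -> None:
    # Hypothetical helper mirroring the is_downloaded check above:
    # fetch the checkpoint only when it is not already on disk.
    is_downloaded = os.path.exists(params_path)
    if not is_downloaded: download()
```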

min_dalle/models/__init__.py (+2 -2)

```diff
@@ -1,3 +1,3 @@
-from .dalle_bart_encoder_torch import DalleBartEncoderTorch
-from .dalle_bart_decoder_torch import DalleBartDecoderTorch
+from .dalle_bart_encoder import DalleBartEncoder
+from .dalle_bart_decoder import DalleBartDecoder
 from .vqgan_detokenizer import VQGanDetokenizer
```

min_dalle/models/dalle_bart_decoder_torch.py renamed to min_dalle/models/dalle_bart_decoder.py (+10 -10)

```diff
@@ -3,10 +3,10 @@
 from torch import LongTensor, nn, FloatTensor, BoolTensor
 torch.set_grad_enabled(False)
 
-from .dalle_bart_encoder_torch import GLUTorch, AttentionTorch
+from .dalle_bart_encoder import GLU, AttentionBase
 
 
-class DecoderCrossAttentionTorch(AttentionTorch):
+class DecoderCrossAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -19,7 +19,7 @@ def forward(
         return super().forward(keys, values, queries, attention_mask)
 
 
-class DecoderSelfAttentionTorch(AttentionTorch):
+class DecoderSelfAttention(AttentionBase):
     def forward(
         self,
         decoder_state: FloatTensor,
@@ -42,7 +42,7 @@ def forward(
         return decoder_state, attention_state
 
 
-class DecoderLayerTorch(nn.Module):
+class DecoderLayer(nn.Module):
     def __init__(
         self,
         image_token_count: int,
@@ -53,12 +53,12 @@ def __init__(
         super().__init__()
         self.image_token_count = image_token_count
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = DecoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = DecoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
         self.pre_encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.encoder_attn = DecoderCrossAttentionTorch(head_count, embed_count)
+        self.encoder_attn = DecoderCrossAttention(head_count, embed_count)
         self.encoder_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
 
         self.token_indices = torch.arange(self.image_token_count)
         if torch.cuda.is_available():
@@ -106,7 +106,7 @@ def forward(
         return decoder_state, attention_state
 
 
-class DalleBartDecoderTorch(nn.Module):
+class DalleBartDecoder(nn.Module):
     def __init__(
         self,
         image_vocab_count: int,
@@ -126,8 +126,8 @@ def __init__(
         self.image_token_count = image_token_count
         self.embed_tokens = nn.Embedding(image_vocab_count + 1, embed_count)
         self.embed_positions = nn.Embedding(image_token_count, embed_count)
-        self.layers: List[DecoderLayerTorch] = nn.ModuleList([
-            DecoderLayerTorch(
+        self.layers: List[DecoderLayer] = nn.ModuleList([
+            DecoderLayer(
                 image_token_count,
                 attention_head_count,
                 embed_count,
```
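The renames also clarify the design: `AttentionBase` owns the shared attention math, and each decoder variant only reroutes which states become keys, values, and queries before delegating to `super().forward`. A toy sketch of that pattern, with the real classes' learned projections and multi-head reshaping omitted:

```python
import torch
from torch import nn, FloatTensor

class AttentionBase(nn.Module):
    # Toy stand-in: the real base class also applies learned projections
    # and splits the embedding across attention heads.
    def forward(
        self,
        keys: FloatTensor,
        values: FloatTensor,
        queries: FloatTensor
    ) -> FloatTensor:
        weights = torch.softmax(queries @ keys.transpose(-1, -2), dim=-1)
        return weights @ values

class DecoderCrossAttention(AttentionBase):
    # Cross-attention: queries come from the decoder state, keys and
    # values from the encoder state; the shared core does the rest.
    def forward(
        self,
        decoder_state: FloatTensor,
        encoder_state: FloatTensor
    ) -> FloatTensor:
        return super().forward(
            keys=encoder_state,
            values=encoder_state,
            queries=decoder_state
        )
```

`DecoderSelfAttention` follows the same shape, broadly sourcing all three roles from the decoder state (plus the cached attention state in the real code).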

min_dalle/models/dalle_bart_encoder_torch.py renamed to min_dalle/models/dalle_bart_encoder.py (+9 -9)

```diff
@@ -4,7 +4,7 @@
 torch.set_grad_enabled(False)
 
 
-class GLUTorch(nn.Module):
+class GLU(nn.Module):
     def __init__(self, count_in_out, count_middle):
         super().__init__()
         self.gelu = nn.GELU()
@@ -24,7 +24,7 @@ def forward(self, z: FloatTensor) -> FloatTensor:
         return z
 
 
-class AttentionTorch(nn.Module):
+class AttentionBase(nn.Module):
     def __init__(self, head_count: int, embed_count: int):
         super().__init__()
         self.head_count = head_count
@@ -72,7 +72,7 @@ def forward(
         return attention_output
 
 
-class EncoderSelfAttentionTorch(AttentionTorch):
+class EncoderSelfAttention(AttentionBase):
     def forward(
         self,
         encoder_state: FloatTensor,
@@ -84,13 +84,13 @@ def forward(
         return super().forward(keys, values, queries, attention_mask)
 
 
-class EncoderLayerTorch(nn.Module):
+class EncoderLayer(nn.Module):
     def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
         super().__init__()
         self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.self_attn = EncoderSelfAttentionTorch(head_count, embed_count)
+        self.self_attn = EncoderSelfAttention(head_count, embed_count)
         self.self_attn_layer_norm = nn.LayerNorm(embed_count)
-        self.glu = GLUTorch(embed_count, glu_embed_count)
+        self.glu = GLU(embed_count, glu_embed_count)
 
     def forward(
         self,
@@ -108,7 +108,7 @@ def forward(
         return encoder_state
 
 
-class DalleBartEncoderTorch(nn.Module):
+class DalleBartEncoder(nn.Module):
     def __init__(
         self,
         layer_count: int,
@@ -121,8 +121,8 @@ def __init__(
         super().__init__()
         self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
         self.embed_positions = nn.Embedding(text_token_count, embed_count)
-        self.layers: List[EncoderLayerTorch] = nn.ModuleList([
-            EncoderLayerTorch(
+        self.layers: List[EncoderLayer] = nn.ModuleList([
+            EncoderLayer(
                 embed_count = embed_count,
                 head_count = attention_head_count,
                 glu_embed_count = glu_embed_count
```

replicate/predict.py (+2 -2)

```diff
@@ -1,11 +1,11 @@
 import tempfile
 from cog import BasePredictor, Path, Input
 
-from min_dalle.min_dalle_torch import MinDalleTorch
+from min_dalle import MinDalle
 
 class Predictor(BasePredictor):
     def setup(self):
-        self.model = MinDalleTorch(is_mega=True)
+        self.model = MinDalle(is_mega=True)
 
     def predict(
         self,
```
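In cog, `setup` runs once when the container starts, so the mega model is loaded a single time and every `predict` call reuses `self.model`. The hunk cuts off at the `predict` signature; purely as an illustration (this is not the repo's actual body, and `generate_image` is an assumption), a predictor in this style might finish as:

```python
import tempfile
from cog import BasePredictor, Input, Path

from min_dalle import MinDalle

class Predictor(BasePredictor):
    def setup(self):
        # Load weights once per container; predict() calls reuse the model.
        self.model = MinDalle(is_mega=True)

    def predict(self, text: str = Input(description='Text prompt')) -> Path:
        # Hypothetical body: generate_image and its signature are assumed.
        image = self.model.generate_image(text)
        out_path = Path(tempfile.mkdtemp()) / 'output.png'
        image.save(str(out_path))
        return out_path
```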

setup.py (+4 -3)

```diff
@@ -3,9 +3,9 @@
 setuptools.setup(
     name='min-dalle',
     description = 'min(DALL·E)',
-    version='0.1.4',
+    version='0.2.0',
     author='Brett Kuprel',
-    author_email = '[email protected]',
+    author_email='[email protected]',
     packages=[
         'min_dalle',
         'min_dalle.models'
@@ -18,6 +18,7 @@
     keywords = [
         'artificial intelligence',
         'deep learning',
-        'text to image'
+        'text-to-image',
+        'pytorch'
     ]
 )
```
