Skip to content

Commit 131e3e5

Browse files
committed
kandinsky
1 parent 08c18de commit 131e3e5

File tree

11 files changed

+102
-121
lines changed

11 files changed

+102
-121
lines changed

README.md

+12-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
99

1010
```
11-
pip install rudalle==1.1.0rc0
11+
pip install rudalle==1.1.0
1212
```
1313
### 🤗 HF Models:
1414
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) \
1515
[ruDALL-E Emojich (XL)](https://huggingface.co/sberbank-ai/rudalle-Emojich) (readme [here](https://github.com/sberbank-ai/ru-dalle/blob/master/Emojich.md)) \
16-
[ruDALL-E Surrealist (XL)](https://huggingface.co/shonenkov-AI/rudalle-xl-surrealist)
17-
16+
[ruDALL-E Surrealist (XL)](https://huggingface.co/shonenkov-AI/rudalle-xl-surrealist) \
17+
ruDALL-E Kandinsky (XXL) (coming soon)
1818

1919
### Minimal Example:
2020

@@ -100,6 +100,15 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
100100
![](https://raw.githubusercontent.com/shonenkov-AI/rudalle-aspect-ratio/main/pics/h_example.jpg)
101101

102102

103+
### Kandinsky
104+
105+
`роботы акварелью в стиле ван гога`
106+
![](./pics/kandinsky/example-robots.png)
107+
108+
[![](./pics/habr_eng.svg)](https://habr.com/ru/company/sberbank/blog/671210/)
109+
110+
![](./pics/kandinsky/loss.jpg)
111+
`FID = 15.4 (COCO Valid)`
103112

104113
### 🚀 Contributors 🚀
105114

pics/kandinsky/example-robots.png

1.25 MB
Loading

pics/kandinsky/loss.jpg

23.7 KB
Loading

rudalle/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@
2121
'image_prompts',
2222
]
2323

24-
__version__ = '1.1.0rc0'
24+
__version__ = '1.1.0'

rudalle/dalle/__init__.py

+36-13
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
),
3030
repo_id='sberbank-ai/rudalle-Malevich',
3131
filename='pytorch_model_v3.bin',
32-
authors='SberAI, SberDevices',
32+
authors='SberAI, SberDevices, shonenkovAI',
3333
full_description='', # TODO
3434
),
3535
'Malevich_v2': dict(
@@ -52,12 +52,12 @@
5252
),
5353
repo_id='sberbank-ai/rudalle-Malevich',
5454
filename='pytorch_model_v2.bin',
55-
authors='SberAI, SberDevices',
55+
authors='SberAI, SberDevices, shonenkovAI',
5656
full_description='', # TODO
5757
),
5858
'Emojich': dict(
5959
hf_version='v2',
60-
description='😋 Emojich is a 1.3 billion params model from the family GPT3-like, '
60+
description='😋 Emojich is 1.3 billion params model from the family GPT3-like, '
6161
'it generates emoji-style images with the brain of ◾ Malevich.',
6262
model_params=dict(
6363
num_layers=24,
@@ -75,9 +75,32 @@
7575
),
7676
repo_id='sberbank-ai/rudalle-Emojich',
7777
filename='pytorch_model.bin',
78-
authors='SberAI',
78+
authors='SberAI, SberDevices, shonenkovAI',
7979
full_description='', # TODO
8080
),
81+
'Surrealist_XL': dict(
82+
hf_version='v3',
83+
description='Surrealist is 1.3 billion params model from the family GPT3-like, '
84+
'that was trained on surrealism and Russian.',
85+
model_params=dict(
86+
num_layers=24,
87+
hidden_size=2048,
88+
num_attention_heads=16,
89+
embedding_dropout_prob=0.1,
90+
output_dropout_prob=0.1,
91+
attention_dropout_prob=0.1,
92+
image_tokens_per_dim=32,
93+
text_seq_length=128,
94+
cogview_sandwich_layernorm=True,
95+
cogview_pb_relax=True,
96+
vocab_size=16384 + 128,
97+
image_vocab_size=8192,
98+
),
99+
repo_id='shonenkov-AI/rudalle-xl-surrealist',
100+
filename='pytorch_model.bin',
101+
authors='shonenkovAI',
102+
full_description='',
103+
),
81104
'Kandinsky': dict(
82105
hf_version='v3',
83106
description='Kandinsky is large 12 billion params model from the family GPT3-like, '
@@ -93,17 +116,16 @@
93116
text_seq_length=128,
94117
cogview_sandwich_layernorm=True,
95118
cogview_pb_relax=True,
96-
cogview_layernorm_prescale=True,
97-
custom_relax=True,
98119
vocab_size=16384 + 128,
99120
image_vocab_size=8192,
100121
),
101-
repo_id='',
102-
filename='',
103-
authors='SberAI, SberDevices',
122+
repo_id='shonenkov-AI/Kandinsky',
123+
filename='pytorch_model.bin',
124+
authors='SberAI, SberDevices, shonenkovAI',
104125
full_description='', # TODO
105126
),
106127
'dummy': dict(
128+
hf_version='v3',
107129
description='',
108130
model_params=dict(
109131
num_layers=12,
@@ -126,20 +148,21 @@
126148
}
127149

128150

129-
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs):
130-
# TODO docstring
151+
def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', use_auth_token=None,
152+
cache_dir='/tmp/rudalle', **model_kwargs):
131153
assert name in MODELS
132154

133155
if fp16 and device == 'cpu':
134156
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
135157

136158
config = MODELS[name].copy()
137159
config['model_params'].update(model_kwargs)
138-
model = DalleModel(device=device, **config['model_params'])
160+
model = DalleModel(device=device, hf_version=config['hf_version'], **config['model_params'])
139161
if pretrained:
140162
cache_dir = os.path.join(cache_dir, name)
141163
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
142-
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
164+
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'],
165+
use_auth_token=use_auth_token)
143166
checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
144167
model.load_state_dict(checkpoint)
145168
if fp16:

rudalle/dalle/model.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ def __init__(self,
2424
loss_img_weight=7,
2525
cogview_sandwich_layernorm=False,
2626
cogview_pb_relax=False,
27-
cogview_layernorm_prescale=False,
28-
custom_relax=False,
2927
is_bool_mask=True,
3028
mlp_activation='gelu_jit',
3129
hf_version='v3'):
@@ -73,8 +71,6 @@ def __init__(self,
7371
image_tokens_per_dim=image_tokens_per_dim,
7472
cogview_sandwich_layernorm=cogview_sandwich_layernorm,
7573
cogview_pb_relax=cogview_pb_relax,
76-
cogview_layernorm_prescale=cogview_layernorm_prescale,
77-
custom_relax=custom_relax,
7874
mlp_activation=mlp_activation,
7975
is_bool_mask=is_bool_mask,
8076
hf_version=self.hf_version,
@@ -110,13 +106,13 @@ def forward(
110106
if self.hf_version == 'v2':
111107
# some hardcode :)
112108
text = F.pad(text, (1, 0), value=2)
113-
text_pos = self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
114-
text_embeddings = self.text_embeddings(text) + text_pos
109+
text_embeddings = self.text_embeddings(text) + \
110+
self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
115111
image_input_ids = input_ids[:, self.text_seq_length:]
116112

117113
if exists(image_input_ids) and not is_empty(image_input_ids):
118-
img_pos = self.get_image_pos_embeddings(image_input_ids, past_length=0)
119-
image_embeddings = self.image_embeddings(image_input_ids) + img_pos
114+
image_embeddings = self.image_embeddings(image_input_ids) + \
115+
self.get_image_pos_embeddings(image_input_ids, past_length=0)
120116
embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
121117
else:
122118
embeddings = text_embeddings

0 commit comments

Comments
 (0)