|
29 | 29 | ),
|
30 | 30 | repo_id='sberbank-ai/rudalle-Malevich',
|
31 | 31 | filename='pytorch_model_v3.bin',
|
32 |
| - authors='SberAI, SberDevices', |
| 32 | + authors='SberAI, SberDevices, shonenkovAI', |
33 | 33 | full_description='', # TODO
|
34 | 34 | ),
|
35 | 35 | 'Malevich_v2': dict(
|
|
52 | 52 | ),
|
53 | 53 | repo_id='sberbank-ai/rudalle-Malevich',
|
54 | 54 | filename='pytorch_model_v2.bin',
|
55 |
| - authors='SberAI, SberDevices', |
| 55 | + authors='SberAI, SberDevices, shonenkovAI', |
56 | 56 | full_description='', # TODO
|
57 | 57 | ),
|
58 | 58 | 'Emojich': dict(
|
59 | 59 | hf_version='v2',
|
60 |
| - description='😋 Emojich is a 1.3 billion params model from the family GPT3-like, ' |
| 60 | + description='😋 Emojich is 1.3 billion params model from the family GPT3-like, ' |
61 | 61 | 'it generates emoji-style images with the brain of ◾ Malevich.',
|
62 | 62 | model_params=dict(
|
63 | 63 | num_layers=24,
|
|
75 | 75 | ),
|
76 | 76 | repo_id='sberbank-ai/rudalle-Emojich',
|
77 | 77 | filename='pytorch_model.bin',
|
78 |
| - authors='SberAI', |
| 78 | + authors='SberAI, SberDevices, shonenkovAI', |
79 | 79 | full_description='', # TODO
|
80 | 80 | ),
|
| 81 | + 'Surrealist_XL': dict( |
| 82 | + hf_version='v3', |
| 83 | + description='Surrealist is 1.3 billion params model from the family GPT3-like, ' |
| 84 | + 'that was trained on surrealism and Russian.', |
| 85 | + model_params=dict( |
| 86 | + num_layers=24, |
| 87 | + hidden_size=2048, |
| 88 | + num_attention_heads=16, |
| 89 | + embedding_dropout_prob=0.1, |
| 90 | + output_dropout_prob=0.1, |
| 91 | + attention_dropout_prob=0.1, |
| 92 | + image_tokens_per_dim=32, |
| 93 | + text_seq_length=128, |
| 94 | + cogview_sandwich_layernorm=True, |
| 95 | + cogview_pb_relax=True, |
| 96 | + vocab_size=16384 + 128, |
| 97 | + image_vocab_size=8192, |
| 98 | + ), |
| 99 | + repo_id='shonenkov-AI/rudalle-xl-surrealist', |
| 100 | + filename='pytorch_model.bin', |
| 101 | + authors='shonenkovAI', |
| 102 | + full_description='', |
| 103 | + ), |
81 | 104 | 'Kandinsky': dict(
|
82 | 105 | hf_version='v3',
|
83 | 106 | description='Kandinsky is large 12 billion params model from the family GPT3-like, '
|
|
93 | 116 | text_seq_length=128,
|
94 | 117 | cogview_sandwich_layernorm=True,
|
95 | 118 | cogview_pb_relax=True,
|
96 |
| - cogview_layernorm_prescale=True, |
97 |
| - custom_relax=True, |
98 | 119 | vocab_size=16384 + 128,
|
99 | 120 | image_vocab_size=8192,
|
100 | 121 | ),
|
101 |
| - repo_id='', |
102 |
| - filename='', |
103 |
| - authors='SberAI, SberDevices', |
| 122 | + repo_id='shonenkov-AI/Kandinsky', |
| 123 | + filename='pytorch_model.bin', |
| 124 | + authors='SberAI, SberDevices, shonenkovAI', |
104 | 125 | full_description='', # TODO
|
105 | 126 | ),
|
106 | 127 | 'dummy': dict(
|
| 128 | + hf_version='v3', |
107 | 129 | description='',
|
108 | 130 | model_params=dict(
|
109 | 131 | num_layers=12,
|
|
126 | 148 | }
|
127 | 149 |
|
128 | 150 |
|
129 |
| -def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir='/tmp/rudalle', **model_kwargs): |
130 |
| - # TODO docstring |
| 151 | +def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', use_auth_token=None, |
| 152 | + cache_dir='/tmp/rudalle', **model_kwargs): |
131 | 153 | assert name in MODELS
|
132 | 154 |
|
133 | 155 | if fp16 and device == 'cpu':
|
134 | 156 | print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
|
135 | 157 |
|
136 | 158 | config = MODELS[name].copy()
|
137 | 159 | config['model_params'].update(model_kwargs)
|
138 |
| - model = DalleModel(device=device, **config['model_params']) |
| 160 | + model = DalleModel(device=device, hf_version=config['hf_version'], **config['model_params']) |
139 | 161 | if pretrained:
|
140 | 162 | cache_dir = os.path.join(cache_dir, name)
|
141 | 163 | config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
|
142 |
| - cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) |
| 164 | + cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'], |
| 165 | + use_auth_token=use_auth_token) |
143 | 166 | checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
|
144 | 167 | model.load_state_dict(checkpoint)
|
145 | 168 | if fp16:
|
|
0 commit comments