Commit 9589235

some fixes (#30)
* some fixes, fill readme
* edit readme, up version
1 parent d3e3c69 commit 9589235

5 files changed (+48 -16)

README.md (+5 -5)

@@ -6,7 +6,7 @@
 [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

 ```
-pip install rudalle==0.0.1rc4
+pip install rudalle==0.0.1rc5
 ```
 ### 🤗 HF Models:
 [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
@@ -18,13 +18,12 @@ pip install rudalle==0.0.1rc4
 [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/shonenkov/rudalle-example-generation)
 [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/anton-l/rudall-e)

-**English translation example**
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)
-
-
 **Finetuning example**
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tb7J4PvvegWOybPfUubl5O7m5I24CBg5?usp=sharing)

+**English translation example**
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fbO6YqtzHAHemY2roWQnXvKkdidNQKO?usp=sharing)
+
 ### generation by ruDALLE:
 ```python
 from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
@@ -95,4 +94,5 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
 - [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
 - [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
 - [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt
+- [@Alex Wortega](https://github.com/AlexWortega) created first FREE version colab notebook with fine-tuning [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) on sneakers domain 💪
 - [@Anton Lozhkov](https://github.com/anton-l) Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio), see [here](https://huggingface.co/spaces/anton-l/rudall-e)
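
The README now pins `rudalle==0.0.1rc5` and keeps the `### generation by ruDALLE` section that imports `generate_images`, `show`, `super_resolution` and `cherry_pick_by_clip`. Below is a minimal generation sketch of that flow (not part of this commit): the loaders `get_rudalle_model`, `get_tokenizer` and `get_vae` are assumed from the package's top-level API, and the two-value return of `generate_images` is likewise assumed.

```python
# Minimal generation sketch, assuming the top-level rudalle loaders
# (get_rudalle_model / get_tokenizer / get_vae) -- they are not shown in this diff.
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images, show

device = 'cuda'  # assumption: a CUDA GPU is available
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)

text = 'изображение радуги на фоне ночного города'  # "a rainbow over a night city"
pil_images, scores = generate_images(text, tokenizer, dalle, vae,
                                     top_k=1024, top_p=0.99, images_num=4)
show(pil_images, nrow=2)
```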

rudalle/__init__.py (+1 -1)

@@ -22,4 +22,4 @@
     'image_prompts',
 ]

-__version__ = '0.0.1-rc4'
+__version__ = '0.0.1-rc5'

rudalle/image_prompts.py (+5 -2)

@@ -18,22 +18,25 @@ def __init__(self, pil_image, borders, vae, device='cpu', crop_first=False):
         self.device = device
         img = self._preprocess_img(pil_image)
         self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
+        self.allow_cache = True

     def _preprocess_img(self, pil_img):
         img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
         img = img.unsqueeze(0).to(self.device, dtype=torch.float32)
         img = (2 * img) - 1
         return img

-    @staticmethod
-    def _get_image_prompts(img, borders, vae, crop_first):
+    def _get_image_prompts(self, img, borders, vae, crop_first):
         if crop_first:
             assert borders['right'] + borders['left'] + borders['down'] == 0
             up_border = borders['up'] * 8
             _, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
         else:
             _, _, [_, _, vqg_img] = vae.model.encode(img)

+        if borders['right'] + borders['left'] + borders['down'] != 0:
+            self.allow_cache = False  # TODO fix cache in attention
+
         bs, vqg_img_w, vqg_img_h = vqg_img.shape
         mask = torch.zeros(vqg_img_w, vqg_img_h)
         if borders['up'] != 0:
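
`_get_image_prompts` becomes an instance method so it can flip the new `self.allow_cache` flag off whenever a border other than `up` is used (the attention cache is not yet correct in that case); `generate_images` in the pipelines diff below reads this flag before deciding whether to keep `use_cache` on. A hedged usage sketch follows, assuming the class is exported as `rudalle.image_prompts.ImagePrompts` (only its `__init__` appears in this diff) and reusing `vae` and `device` from the generation sketch above:

```python
from PIL import Image

from rudalle.image_prompts import ImagePrompts  # assumed class name / export

pil_img = Image.open('rainbow.png')  # hypothetical input image

# Only the 'up' border is non-zero: caching stays allowed.
borders = {'up': 4, 'left': 0, 'right': 0, 'down': 0}
image_prompts = ImagePrompts(pil_img, borders, vae, device=device, crop_first=True)
print(image_prompts.allow_cache)  # True

# Any left/right/down border disables caching (see the TODO in the diff above).
borders = {'up': 0, 'left': 4, 'right': 0, 'down': 0}
image_prompts = ImagePrompts(pil_img, borders, vae, device=device)
print(image_prompts.allow_cache)  # False
```

Passing such an object to `generate_images(..., image_prompts=image_prompts)` now only prints the `use_cache` warning when caching actually has to be turned off.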

rudalle/pipelines.py (+28 -8)

@@ -1,4 +1,8 @@
 # -*- coding: utf-8 -*-
+import os
+from glob import glob
+from os.path import join
+
 import torch
 import torchvision
 import transformers
@@ -34,10 +38,10 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
             sample_scores = []
             if image_prompts is not None:
                 prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
-                prompts = prompts.repeat(images_num, 1)
-                if use_cache:
-                    use_cache = False
+                prompts = prompts.repeat(chunk_bs, 1)
+                if use_cache and image_prompts.allow_cache is False:
                     print('Warning: use_cache changed to False')
+                    use_cache = False
             for idx in tqdm(range(out.shape[1], total_seq_length)):
                 idx -= text_seq_length
                 if image_prompts is not None and idx in prompts_idx:
@@ -84,15 +88,31 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
     return top_pil_images, top_scores


-def show(pil_images, nrow=4):
+def show(pil_images, nrow=4, save_dir=None, show=True):
+    """
+    :param pil_images: list of images in PIL
+    :param nrow: number of rows
+    :param save_dir: dir for separately saving of images, example: save_dir='./pics'
+    """
+    if save_dir is not None:
+        os.makedirs(save_dir, exist_ok=True)
+        count = len(glob(join(save_dir, 'img_*.png')))
+        for i, pil_image in enumerate(pil_images):
+            pil_image.save(join(save_dir, f'img_{count+i}.png'))
+
     imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
     if not isinstance(imgs, list):
         imgs = [imgs.cpu()]
     fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
     for i, img in enumerate(imgs):
         img = img.detach()
         img = torchvision.transforms.functional.to_pil_image(img)
-        axs[0, i].imshow(np.asarray(img))
-        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-    fix.show()
-    plt.show()
+        if save_dir is not None:
+            count = len(glob(join(save_dir, 'group_*.png')))
+            img.save(join(save_dir, f'group_{count+i}.png'))
+        if show:
+            axs[0, i].imshow(np.asarray(img))
+            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+    if show:
+        fix.show()
+        plt.show()
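
`show` keeps its old default behaviour (`nrow=4`, display on) but can now also save each image as `img_<n>.png` and each assembled grid as `group_<n>.png` into `save_dir` (the counters continue from files already present there), and it can skip the matplotlib display via `show=False`. A short usage sketch, reusing `pil_images` from the generation sketch above:

```python
from rudalle.pipelines import show

# Display the grid and also save img_*.png / group_*.png into ./pics.
show(pil_images, nrow=2, save_dir='./pics')

# Save only, without rendering the grid (this is what the new test does).
show(pil_images, nrow=2, save_dir='./pics', show=False)
```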

tests/test_show.py (+9)

@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+from rudalle.pipelines import show
+
+
+def test_show(sample_image):
+    img = sample_image.copy()
+    img = img.resize((256, 256))
+    pil_images = [img]*5
+    show(pil_images, nrow=2, save_dir='/tmp/pics', show=False)
