|
1 | 1 | # -*- coding: utf-8 -*-
|
| 2 | +import os |
| 3 | +from glob import glob |
| 4 | +from os.path import join |
| 5 | + |
2 | 6 | import torch
|
3 | 7 | import torchvision
|
4 | 8 | import transformers
|
@@ -34,10 +38,10 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
|
34 | 38 | sample_scores = []
|
35 | 39 | if image_prompts is not None:
|
36 | 40 | prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
|
37 |
| - prompts = prompts.repeat(images_num, 1) |
38 |
| - if use_cache: |
39 |
| - use_cache = False |
| 41 | + prompts = prompts.repeat(chunk_bs, 1) |
| 42 | + if use_cache and image_prompts.allow_cache is False: |
40 | 43 | print('Warning: use_cache changed to False')
|
| 44 | + use_cache = False |
41 | 45 | for idx in tqdm(range(out.shape[1], total_seq_length)):
|
42 | 46 | idx -= text_seq_length
|
43 | 47 | if image_prompts is not None and idx in prompts_idx:
|
@@ -84,15 +88,31 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
|
84 | 88 | return top_pil_images, top_scores
|
85 | 89 |
|
86 | 90 |
|
87 |
def show(pil_images, nrow=4, save_dir=None, show=True):
    """
    Display a grid of PIL images and optionally save them to disk.

    :param pil_images: list of images in PIL format
    :param nrow: number of images per row in the grid
    :param save_dir: dir for separately saving of images, example: save_dir='./pics'
    :param show: if True, render the grid with matplotlib; if False, only save
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        # Continue numbering after any 'img_*.png' files already in save_dir.
        count = len(glob(join(save_dir, 'img_*.png')))
        for i, pil_image in enumerate(pil_images):
            pil_image.save(join(save_dir, f'img_{count + i}.png'))

    imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
    if not isinstance(imgs, list):
        imgs = [imgs.cpu()]
    # Create the matplotlib figure only when it will actually be rendered;
    # previously it was built (and never closed) even with show=False.
    if show:
        fig, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
    if save_dir is not None:
        # Hoisted out of the loop: recomputing the glob count after each save
        # made it grow together with i, yielding gapped names
        # (group_0, group_2, group_4, ...) instead of consecutive ones.
        group_count = len(glob(join(save_dir, 'group_*.png')))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = torchvision.transforms.functional.to_pil_image(img)
        if save_dir is not None:
            img.save(join(save_dir, f'group_{group_count + i}.png'))
        if show:
            axs[0, i].imshow(np.asarray(img))
            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if show:
        fig.show()
        plt.show()
0 commit comments