Skip to content

Commit a23a834

Browse files
authored
Add size param to show() (#40)
1 parent 8d762f5 commit a23a834

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

Diff for: rudalle/pipelines.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
8585
return top_pil_images, top_scores
8686

8787

88-
def show(pil_images, nrow=4, save_dir=None, show=True):
88+
def show(pil_images, nrow=4, size=14, save_dir=None, show=True):
8989
"""
9090
:param pil_images: list of images in PIL
9191
:param nrow: number of rows
92+
:param size: size of the images
9293
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
9394
"""
9495
if save_dir is not None:
@@ -100,7 +101,7 @@ def show(pil_images, nrow=4, save_dir=None, show=True):
100101
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
101102
if not isinstance(imgs, list):
102103
imgs = [imgs.cpu()]
103-
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
104+
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))
104105
for i, img in enumerate(imgs):
105106
img = img.detach()
106107
img = torchvision.transforms.functional.to_pil_image(img)

0 commit comments

Comments
 (0)