@@ -85,10 +85,11 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
85
85
return top_pil_images , top_scores
86
86
87
87
88
- def show (pil_images , nrow = 4 , save_dir = None , show = True ):
88
+ def show (pil_images , nrow = 4 , size = 14 , save_dir = None , show = True ):
89
89
"""
90
90
:param pil_images: list of images in PIL
91
91
:param nrow: number of rows
92
+ :param size: size of the images
92
93
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
93
94
"""
94
95
if save_dir is not None :
@@ -100,7 +101,7 @@ def show(pil_images, nrow=4, save_dir=None, show=True):
100
101
imgs = torchvision .utils .make_grid (utils .pil_list_to_torch_tensors (pil_images ), nrow = nrow )
101
102
if not isinstance (imgs , list ):
102
103
imgs = [imgs .cpu ()]
103
- fix , axs = plt .subplots (ncols = len (imgs ), squeeze = False , figsize = (14 , 14 ))
104
+ fix , axs = plt .subplots (ncols = len (imgs ), squeeze = False , figsize = (size , size ))
104
105
for i , img in enumerate (imgs ):
105
106
img = img .detach ()
106
107
img = torchvision .transforms .functional .to_pil_image (img )
0 commit comments