Skip to content

Commit 4963ea7

Browse files
committed
correct tqdm when inpainting pass is skipped
1 parent bb34562 commit 4963ea7

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

scripts/deforum/generate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from PIL import Image, ImageOps
44
import requests
55
import numpy as np
6+
from math import ceil
67
import torchvision.transforms.functional as TF
78
from pytorch_lightning import seed_everything
89
import os
@@ -23,6 +24,7 @@
2324
import cv2
2425
from .animation import sample_from_cv2, sample_to_cv2
2526
from modules import processing, masking
27+
import modules.shared as shared
2628
from modules.shared import opts, sd_model
2729
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
2830

@@ -185,6 +187,8 @@ def generate(args, root, frame = 0, return_sample=False):
185187
#color correction for zeroes inpainting
186188
p.color_corrections = [processing.setup_color_correction(init_image)]
187189

190+
print("Inpainting zeros")
191+
188192
processed = processing.process_images(p)
189193

190194
init_image = processed.images[0].convert('RGB')
@@ -193,6 +197,13 @@ def generate(args, root, frame = 0, return_sample=False):
193197
p.image_mask = None
194198
mask_image = None
195199
processed = None
200+
else:
201+
# fix tqdm total steps if we don't have to conduct a second pass
202+
tqdm_instance = shared.total_tqdm
203+
current_total = tqdm_instance.getTotal()
204+
if current_total != -1:
205+
tqdm_instance.updateTotal(current_total - int(ceil(args.steps * (1-args.strength))))
206+
196207

197208
elif args.use_init and args.init_image != None and args.init_image != '':
198209
init_image, mask_image = load_img(args.init_image,

scripts/deforum/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,10 @@ def clear(self):
200200
if self._tqdm is not None:
201201
self._tqdm.close()
202202
self._tqdm = None
203+
204+
def getTotal(self):
205+
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
206+
return -1
207+
if self._tqdm is None:
208+
self.reset()
209+
return self._tqdm.total

0 commit comments

Comments
 (0)