Skip to content

Smart border inpainting mode #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions javascript/deforum-hints.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ deforum_titles = {
"Border": "controls handling method of pixels to be generated when the image is smaller than the frame.",
"wrap": "pulls pixels from the opposite edge of the image",
"replicate": "repeats the edge of the pixels, and extends them. Animations with quick motion may yield lines where this border function was attempting to populate pixels into the empty space created.",
"smart": "makes a second 'inpaint' pass to fill empty space when zooming out or shifting the frame; useful to fix the 'border stripes' and make the 'sidespace' feel more open; overrides padding_mode",
"Angle": "2D operator to rotate canvas clockwise/anticlockwise in degrees per frame",
"Zoom": "2D operator that scales the canvas size, multiplicatively. [static = 1.0]",
"Translation X": "2D & 3D operator to move canvas left/right in pixels per frame",
Expand Down
59 changes: 46 additions & 13 deletions scripts/deforum_helpers/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def flip_3d_perspective(anim_args, prev_img_cv2, keys, frame_idx):

def anim_frame_warp(prev_img_cv2, args, anim_args, keys, frame_idx, depth_model=None, depth=None, device='cuda', half_precision = False):

warp_mask = None

if anim_args.use_depth_warping:
if depth is None and depth_model is not None:
depth = depth_model.predict(prev_img_cv2, anim_args, half_precision)
Expand All @@ -166,9 +168,8 @@ def anim_frame_warp(prev_img_cv2, args, anim_args, keys, frame_idx, depth_model=
if anim_args.animation_mode == '2D':
prev_img = anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx)
else: # '3D'
prev_img = anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx)

return prev_img, depth
prev_img, warp_mask = anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx)
return prev_img, depth, warp_mask

def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
angle = keys.angle_series[frame_idx]
Expand All @@ -183,15 +184,32 @@ def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
rot_mat = np.vstack([rot_mat, [0,0,1]])
if anim_args.enable_perspective_flip:
bM = get_flip_perspective_matrix(args.W, args.H, keys, frame_idx)
rot_mat = np.matmul(bM, rot_mat, trans_mat)
xform = np.matmul(bM, rot_mat, trans_mat)
else:
rot_mat = np.matmul(rot_mat, trans_mat)
return cv2.warpPerspective(
prev_img_cv2,
rot_mat,
(prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
)
xform = np.matmul(rot_mat, trans_mat)

borderMode = cv2.BORDER_CONSTANT #zeros

if anim_args.border == 'wrap':
borderMode = cv2.BORDER_WRAP
elif anim_args.border == 'replicate':
borderMode = cv2.BORDER_REPLICATE

if borderMode == 'smart':
return cv2.warpPerspective(
prev_img_cv2,
xform,
(prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
borderMode=borderMode,
borderValue=(0, 0, 0,),
)
else:
return cv2.warpPerspective(
prev_img_cv2,
xform,
(prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
borderMode=borderMode,
)

def anim_frame_warp_3d(device, prev_img_cv2, depth, anim_args, keys, frame_idx):
TRANSLATION_SCALE = 1.0/200.0 # matches Disco
Expand Down Expand Up @@ -246,7 +264,7 @@ def transform_image_3d(device, prev_img_cv2, depth_tensor, rot_mat, translate, a
image_tensor.add(1/512 - 0.0001).unsqueeze(0),
offset_coords_2d,
mode=anim_args.sampling_mode,
padding_mode=anim_args.padding_mode,
padding_mode=anim_args.padding_mode if anim_args.border is not 'smart' else 'zeros', # border overrides padding_mode without changing the settings saved
align_corners=False
)

Expand All @@ -255,4 +273,19 @@ def transform_image_3d(device, prev_img_cv2, depth_tensor, rot_mat, translate, a
new_image.squeeze().clamp(0,255),
'c h w -> h w c'
).cpu().numpy().astype(prev_img_cv2.dtype)
return result

warp_mask = torch.ones_like(image_tensor) * 255
warp_mask_image = torch.nn.functional.grid_sample(
warp_mask.add(1/512 - 0.0001).unsqueeze(0),
offset_coords_2d,
mode="nearest",
padding_mode="zeros",
align_corners=False
)

result_mask = rearrange(
warp_mask_image.squeeze().clamp(0,255),
'c h w -> h w c'
).cpu().numpy().astype(prev_img_cv2.dtype)

return result, result_mask
14 changes: 10 additions & 4 deletions scripts/deforum_helpers/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def Root():
first_frame = None
outpath_samples = ""
animation_prompts = None
color_corrections = None
color_corrections = None
warp_mask = None
initial_clipskip = None
current_user_os = get_os()
return locals()
Expand All @@ -32,7 +33,8 @@ def DeforumAnimArgs():
#@markdown ####**Animation:**
animation_mode = '2D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
max_frames = 120 #@param {type:"number"}
border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}
border = 'smart' #@param ['wrap', 'replicate', 'smart'] {type:'string'}

#@markdown ####**Motion Parameters:**
angle = "0:(0)"#@param {type:"string"}
zoom = "0:(1.0025+0.002*sin(1.25*3.14*t/30))"#@param {type:"string"}
Expand Down Expand Up @@ -222,6 +224,7 @@ def DeforumArgs():
full_res_mask = True
full_res_mask_padding = 4
reroll_blank_frames = 'reroll' # reroll, interrupt, or ignore
smart_border_fill_mode = 2

n_samples = 1 # doesnt do anything
precision = 'autocast'
Expand Down Expand Up @@ -384,7 +387,7 @@ def show_vid():
with gr.Column(scale=2):
animation_mode = gr.Radio(['2D', '3D', 'Interpolation', 'Video Input'], label="Animation mode", value=da.animation_mode, elem_id="animation_mode")
with gr.Column(scale=1, min_width=180):
border = gr.Radio(['replicate', 'wrap'], label="Border", value=da.border, elem_id="border")
border = gr.Radio(['replicate', 'wrap', 'smart'], label="Border", value=da.border, elem_id="border")
with gr.Row(variant='compact'):
with gr.Column(scale=5):
with gr.Row(variant='compact'):
Expand Down Expand Up @@ -635,6 +638,9 @@ def show_vid():
with gr.Row():
choice = mask_fill_choices[d.fill]
fill = gr.Radio(label='Mask fill', choices=mask_fill_choices, value=choice, type="index")
with gr.Row():
choice = mask_fill_choices[d.smart_border_fill_mode]
smart_border_fill_mode = gr.Radio(label='Smart border fill', choices=mask_fill_choices, value=choice, type="index")
with gr.Row():
full_res_mask = gr.Checkbox(label="Full res mask", value=d.full_res_mask, interactive=True)
full_res_mask_padding = gr.Slider(minimum=0, maximum=512, step=1, label="Full res mask padding", value=d.full_res_mask_padding, interactive=True)
Expand Down Expand Up @@ -1036,7 +1042,7 @@ def update_upscale_out_res_by_model_name(in_res, upscale_model_name):
use_init, from_img2img_instead_of_link, strength_0_no_init, strength, init_image,
use_mask, use_alpha_as_mask, invert_mask, overlay_mask,
mask_file, mask_contrast_adjust, mask_brightness_adjust, mask_overlay_blur,
fill, full_res_mask, full_res_mask_padding,
fill, smart_border_fill_mode, full_res_mask, full_res_mask_padding,
reroll_blank_frames'''
).replace("\n", "").replace("\r", "").replace(" ", "").split(',')
video_args_names = str(r'''skip_video_for_run_all,
Expand Down
74 changes: 73 additions & 1 deletion scripts/deforum_helpers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#Webui
import cv2
from .animation import sample_from_cv2, sample_to_cv2
from modules import processing, sd_models
from modules import processing, sd_models, masking
from modules.shared import opts, sd_model
import modules.shared as shared
from modules.processing import process_images, StableDiffusionProcessingTxt2Img

import math, json, itertools
Expand Down Expand Up @@ -140,6 +141,77 @@ def generate(args, anim_args, loop_args, root, frame = 0, return_sample=False, s
init_image = Image.blend(init_image, init_image2, blendFactor)
correction_colors = Image.blend(init_image, init_image2, colorCorrectionFactor)
p.color_corrections = [processing.setup_color_correction(correction_colors)]

if anim_args.border == 'smart':

# Inpaint the parts of the image that changed,
# that is to say, the zeros we got after the transformations

# It's important to note that the loop below creates a mask for inpainting zeros.
# This mask, however, can also cover areas that were intended to be black.
# Suggested fix: pass the inpainting mask in as an argument,
# before add_noise and contrast_adjust are applied
mask_image = init_image.convert('L')
for x in range(mask_image.width):
for y in range(mask_image.height):
if mask_image.getpixel((x,y)) < 4:
mask_image.putpixel((x,y), 255)
else:
mask_image.putpixel((x,y), 0)

# blend the two masks
if root.warp_mask is not None:
# TODO: I guess there is some built-in function for this
warp_mask_image = Image.fromarray(root.warp_mask).convert('L')
for x in range(mask_image.width):
for y in range(mask_image.height):
if mask_image.getpixel((x,y)) > 0 or warp_mask_image.getpixel((x,y)) == 0:
mask_image.putpixel((x,y), 255)
else:
mask_image.putpixel((x,y), 0)
root.warp_mask = None

mask = prepare_mask(mask_image,
(args.W, args.H),
args.mask_contrast_adjust,
args.mask_brightness_adjust)

# HACK: this is a hacky check to make the mask work with the new inpainting code
crop_region = masking.get_crop_region(np.array(mask_image), args.full_res_mask_padding)
crop_region = masking.expand_crop_region(crop_region, args.W, args.H, mask_image.width, mask_image.height)
x1, y1, x2, y2 = crop_region

too_small = (x2 - x1) < 1 or (y2 - y1) < 1

if not too_small:
p.do_not_save_samples=True,
p.inpainting_fill = args.smart_border_fill_mode
p.inpaint_full_res= args.full_res_mask
p.inpaint_full_res_padding = args.full_res_mask_padding
p.init_images = [init_image]
p.image_mask = mask_image

#color correction for zeroes inpainting
p.color_corrections = [processing.setup_color_correction(init_image)]

print("Smart mode: inpainting border")

processed = processing.process_images(p)
init_image = processed.images[0].convert('RGB')

p = get_webui_sd_pipeline(args, root, frame)
p.init_images = [init_image]

processed = None
else:
# fix tqdm total steps if we don't have to conduct a second pass
tqdm_instance = shared.total_tqdm
current_total = tqdm_instance.getTotal()
if current_total != -1:
tqdm_instance.updateTotal(current_total - int(math.ceil(args.steps * (1-args.strength))))

mask = None
mask_image = None

# this is the first pass
elif loop_args.use_looper or (args.use_init and ((args.init_image != None and args.init_image != ''))):
Expand Down
7 changes: 4 additions & 3 deletions scripts/deforum_helpers/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, animat
depth = depth_model.predict(turbo_next_image, anim_args, root.half_precision)

if advance_prev:
turbo_prev_image, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device, half_precision=root.half_precision)
turbo_prev_image, _, _ = anim_frame_warp(turbo_prev_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device, half_precision=root.half_precision)
if advance_next:
turbo_next_image, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device, half_precision=root.half_precision)
turbo_next_image, _, _ = anim_frame_warp(turbo_next_image, args, anim_args, keys, tween_frame_idx, depth_model, depth=depth, device=root.device, half_precision=root.half_precision)

# hybrid video motion - warps turbo_prev_image or turbo_next_image to match motion
if tween_frame_idx > 0:
Expand Down Expand Up @@ -276,7 +276,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, animat

# apply transforms to previous frame
if prev_img is not None:
prev_img, depth = anim_frame_warp(prev_img, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device, half_precision=root.half_precision)
prev_img, depth, warp_mask = anim_frame_warp(prev_img, args, anim_args, keys, frame_idx, depth_model, depth=None, device=root.device, half_precision=root.half_precision)

# hybrid video motion - warps prev_img to match motion, usually to prepare for compositing
if frame_idx > 0:
Expand Down Expand Up @@ -333,6 +333,7 @@ def render_animation(args, anim_args, video_args, parseq_args, loop_args, animat
args.use_init = True
args.init_sample = Image.fromarray(cv2.cvtColor(noised_image, cv2.COLOR_BGR2RGB))
args.strength = max(0.0, min(1.0, strength))
root.warp_mask = warp_mask

args.scale = scale

Expand Down
22 changes: 22 additions & 0 deletions scripts/deforum_helpers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ def load_settings(*args, **kwargs):
logging.debug(f"Fill not found in load file, using default value: {fill_default}")
ret.append(mask_fill_choices[fill_default])

elif key == 'smart_border_fill_mode':
if key in jdata:
smart_border_fill_mode_val = jdata[key]
if type(smart_border_fill_mode_val) == int:
ret.append(mask_fill_choices[smart_border_fill_mode_val])
else:
ret.append(smart_border_fill_mode_val)
else:
smart_border_fill_mode_default = DeforumArgs()['smart_border_fill_mode']
logging.debug(f"Smart border fill mode not found in load file, using default value: {smart_border_fill_mode_default}")
ret.append(mask_fill_choices[smart_border_fill_mode_default])

elif key == 'reroll_blank_frames':
if key in jdata:
reroll_blank_frames_val = jdata[key]
Expand Down Expand Up @@ -237,6 +249,9 @@ def reset(self):
deforum_total += self._args.steps
had_first = True
else:
# count this frame's steps twice in smart border mode (it makes a second, inpainting pass)
if self._anim_args.border == 'smart':
deforum_total += int(ceil(self._args.steps * (1-strength)))
deforum_total += int(ceil(self._args.steps * (1-strength)))

if turbo_steps > 1:
Expand Down Expand Up @@ -269,3 +284,10 @@ def clear(self):
if self._tqdm is not None:
self._tqdm.close()
self._tqdm = None

def getTotal(self):
    """Return the total step count of the progress bar, or -1 if bars are off.

    -1 is returned when ``opts.multiple_tqdm`` is falsy or
    ``cmd_opts.disable_console_progressbars`` is truthy. Otherwise the
    ``total`` attribute of ``self._tqdm`` is returned.
    """
    bars_disabled = not opts.multiple_tqdm or cmd_opts.disable_console_progressbars
    if bars_disabled:
        return -1

    # No bar yet: reset() is called first; presumably it creates self._tqdm
    # (confirm against reset(), which is only partly visible here).
    if self._tqdm is None:
        self.reset()

    return self._tqdm.total