
Commit ac22fd2

Merge pull request #9 from Sxela/alex/combine-nodes
Alex/combine nodes - 0.6.0
2 parents 9bb3075 + 0709c7e commit ac22fd2

6 files changed (+281, -72 lines changed)


flow_utils.py

Lines changed: 7 additions & 7 deletions
@@ -251,9 +251,9 @@ def get_flow_and_mask(frame1, frame2, num_flow_updates=20, raft_model=None, edge
     occlusion_mask, _ = get_unreliable(predicted_flows)
     _, overshoot = get_unreliable(predicted_flows_bwd)
 
-    occlusion_mask = (torch.from_numpy(255-(filter_unreliable(occlusion_mask, dilation)*255)).transpose(0,1)/255).cpu()
-    border_mask = (torch.from_numpy(overshoot*255).transpose(0,1)/255).cpu()
-    edge_mask = (torch.from_numpy(255-edge).transpose(0,1)/255).cpu()
+    occlusion_mask = (torch.from_numpy(255-(filter_unreliable(occlusion_mask, dilation)*255)).transpose(0,1)/255).cpu()[None,...]
+    border_mask = (torch.from_numpy(overshoot*255).transpose(0,1)/255).cpu()[None,...]
+    edge_mask = (torch.from_numpy(255-edge).transpose(0,1)/255).cpu()[None,...]
     print(flow_imgs.max(), flow_imgs.min())
     flow_imgs = (torch.from_numpy(flow_imgs.transpose(1,0,2))/255).cpu()[None,]
     raft_model.cpu()
@@ -291,9 +291,9 @@ def apply_warp(current_frame, flow, padding=0):
 def mix_cc(missed_cc, overshoot_cc, edge_cc, blur=2, dilate=0, missed_consistency_weight=1,
           overshoot_consistency_weight=1, edges_consistency_weight=1, force_binary=True):
     #accepts 3 maps [h x w] 0-1 range
-    missed_cc = np.array(missed_cc)
-    overshoot_cc = np.array(overshoot_cc)
-    edge_cc = np.array(edge_cc)
+    missed_cc = np.array(missed_cc)[0]
+    overshoot_cc = np.array(overshoot_cc)[0]
+    edge_cc = np.array(edge_cc)[0]
     weights = np.ones_like(missed_cc)
     weights*=missed_cc.clip(1-missed_consistency_weight,1)
     weights*=overshoot_cc.clip(1-overshoot_consistency_weight,1)
@@ -304,4 +304,4 @@ def mix_cc(missed_cc, overshoot_cc, edge_cc, blur=2, dilate=0, missed_consistenc
     weights = (1-binary_dilation(1-weights, disk(dilate))).astype('uint8')
     if blur>0: weights = scipy.ndimage.gaussian_filter(weights, [blur, blur])
 
-    return torch.from_numpy(weights)
+    return torch.from_numpy(weights)[None,...]
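
Note on the flow_utils.py change: the occlusion, border, and edge masks now carry a leading batch axis ([1, h, w], matching the MASK layout the downstream ComfyUI nodes expect), and mix_cc strips that axis with [0] before its NumPy math, then restores it on the returned weights. A minimal shape sketch of that round trip; the dimensions below are made up for illustration:

import numpy as np
import torch

h, w = 4, 6
mask_hw = np.random.rand(h, w).astype(np.float32)    # a raw [h, w] consistency map
mask_bhw = torch.from_numpy(mask_hw)[None, ...]      # [1, h, w], as returned above
assert mask_bhw.shape == (1, h, w)

# mix_cc-style consumers drop the batch axis again before the NumPy math,
# then put it back on the returned weights:
mask_back = np.array(mask_bhw)[0]
assert mask_back.shape == (h, w)
weights = torch.from_numpy(mask_back)[None, ...]     # [1, h, w] again
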

frame_nodes.py

Lines changed: 164 additions & 13 deletions
@@ -6,6 +6,127 @@
 import folder_paths
 from .frame_utils import FrameDataset, StylizedFrameDataset, get_scheduled_arg, get_size, save_video
 
+class ApplyMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "destination": ("IMAGE",),
+                "source": ("IMAGE",),
+            },
+            "optional": {
+                "mask": ("MASK",),
+            }
+        }
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "composite"
+
+    CATEGORY = "WarpFusion"
+
+    def composite(self, destination, source, mask = None):
+
+        mask = mask[..., None].repeat(1,1,1,destination.shape[-1])
+        res = destination*(1-mask) + source*(mask)
+        return (res,)
+
+class ApplyMaskConditional:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "destination": ("IMAGE",),
+                "source": ("IMAGE",),
+                "current_frame_number": ("INT",),
+                "apply_at_frames": ("STRING",),
+                "don_not_apply_at_frames": ("BOOLEAN",),
+            },
+            "optional": {
+                "mask": ("MASK",),
+            }
+        }
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "composite"
+
+    CATEGORY = "WarpFusion"
+
+    def composite(self, destination, source, current_frame_number, apply_at_frames, don_not_apply_at_frames, mask = None):
+        idx_list = [int(i) for i in apply_at_frames.split(',')]
+        if (current_frame_number not in idx_list) if don_not_apply_at_frames else (current_frame_number in idx_list):
+            # Convert mask to correct format for interpolation [b,c,h,w]
+            mask = mask[None,...]
+
+            # Resize mask to destination size using explicit dimensions
+            mask = torch.nn.functional.interpolate(mask, size=(destination.shape[1], destination.shape[2]), mode='bilinear')
+
+            # Convert back to [b,h,w,1] format
+            mask = mask[0,...,None].repeat(1,1,1,destination.shape[-1])
+
+            source = source.permute(0,3,1,2)
+            source = torch.nn.functional.interpolate(source, size=(destination.shape[1], destination.shape[2]), mode='bilinear')
+            source = source.permute(0,2,3,1)
+
+            res = destination*(1-mask) + source*(mask)
+            return (res,)
+        else:
+            return (destination,)
+
+class ApplyMaskLatent:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "destination": ("LATENT",),
+                "source": ("LATENT",),
+            },
+            "optional": {
+                "mask": ("MASK",),
+            }
+        }
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "composite"
+
+    CATEGORY = "WarpFusion"
+
+    def composite(self, destination, source, mask = None):
+        destination = destination['samples']
+        source = source['samples']
+        mask = mask[None, ...]
+        mask = torch.nn.functional.interpolate(mask, size=(destination.shape[2], destination.shape[3]))
+        res = destination*(1-mask) + source*(mask)
+        return ({"samples":res}, )
+
+class ApplyMaskLatentConditional:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "destination": ("LATENT",),
+                "source": ("LATENT",),
+                "current_frame_number": ("INT",),
+                "apply_at_frames": ("STRING",),
+                "don_not_apply_at_frames": ("BOOLEAN",),
+            },
+            "optional": {
+                "mask": ("MASK",),
+            }
+        }
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "composite"
+
+    CATEGORY = "WarpFusion"
+
+    def composite(self, destination, source, current_frame_number, apply_at_frames, don_not_apply_at_frames, mask = None):
+        destination = destination['samples']
+        source = source['samples']
+        idx_list = [int(i) for i in apply_at_frames.split(',')]
+        if (current_frame_number not in idx_list) if don_not_apply_at_frames else (current_frame_number in idx_list):
+            mask = mask[None, ...]
+            mask = torch.nn.functional.interpolate(mask, size=(destination.shape[2], destination.shape[3]))
+            res = destination*(1-mask) + source*(mask)
+            return ({"samples":res}, )
+        else:
+            return ({"samples":destination}, )
+
 class LoadFrameSequence:
     @classmethod
     def INPUT_TYPES(self):
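
Note on the new nodes above: all four composite with the same linear blend, destination*(1-mask) + source*mask, and the Conditional variants gate that blend on the current frame number parsed from a comma-separated apply_at_frames string (inverted when don_not_apply_at_frames is enabled). A standalone sketch of the blend and the gating outside the node graph; the tensors and frame numbers here are invented for illustration:

import torch

destination = torch.zeros(1, 8, 8, 3)   # IMAGE batch: [b, h, w, c]
source = torch.ones(1, 8, 8, 3)
mask = torch.ones(1, 8, 8)              # MASK batch: [b, h, w]

# ApplyMask-style composite: broadcast the mask over channels and lerp.
m = mask[..., None].repeat(1, 1, 1, destination.shape[-1])
blended = destination * (1 - m) + source * m

# ApplyMaskConditional-style gating: comma-separated frame list,
# optionally inverted by don_not_apply_at_frames.
apply_at_frames = "0,5,10"
don_not_apply_at_frames = False
current_frame_number = 5
idx_list = [int(i) for i in apply_at_frames.split(',')]
apply = (current_frame_number not in idx_list) if don_not_apply_at_frames \
        else (current_frame_number in idx_list)
result = blended if apply else destination
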
@@ -90,26 +211,21 @@ def INPUT_TYPES(self):
             "start_frame":("INT", {"default": 0, "min": 0, "max": 9999999999}),
             "end_frame":("INT", {"default": -1, "min": -1, "max": 9999999999}),
             "nth_frame":("INT", {"default": 1, "min": 1, "max": 9999999999}),
-            },
+            "overwrite":("BOOLEAN", {"default": False})
+            }
         }
 
     CATEGORY = "WarpFusion"
     RETURN_TYPES = ("FRAME_DATASET", "INT")
     RETURN_NAMES = ("FRAME_DATASET", "Total_frames")
     FUNCTION = "get_frames"
 
-    def get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame):
+    def get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame, overwrite):
         ds = FrameDataset(file_path, outdir_prefix='', videoframes_root=folder_paths.get_output_directory(),
-                          update_on_getitem=update_on_frame_load, start_frame=start_frame, end_frame=end_frame, nth_frame=nth_frame)
+                          update_on_getitem=update_on_frame_load, start_frame=start_frame, end_frame=end_frame, nth_frame=nth_frame, overwrite=overwrite)
+        if len(ds)==0:
+            raise Exception(f"Found 0 frames in path {file_path}") #thanks to https://github.com/Aljnk
         return (ds,len(ds))
-
-    @classmethod
-    def VALIDATE_INPUTS(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame):
-        _, n_frames = self.get_frames(self, file_path, update_on_frame_load, start_frame, end_frame, nth_frame)
-        if n_frames==0:
-            return f"Found 0 frames in path {file_path}"
-
-        return True
 
 class LoadFrameFromFolder:
     @classmethod
@@ -308,6 +424,7 @@ def export_video(self, output_dir, frames_input_dir, batch_name, first_frame=1,
         print('Exporting video.')
         save_video(indir=frames_input_dir, video_out=output_dir, batch_name=batch_name, start_frame=first_frame,
                    last_frame=last_frame, fps=fps, output_format=output_format, use_deflicker=use_deflicker)
+        # raise Exception(f'Exported video successfully. This exception is raised to just stop the endless cycle :D.\n you can find your video at {output_dir}')
         return ()
 
 class SchedulerInt:
@@ -396,6 +513,30 @@ def INPUT_TYPES(self):
     def get_value(self, start, end, current_number):
         return (current_number, start, end)
 
+class MakePaths:
+    @classmethod
+    def INPUT_TYPES(self):
+        return {"required": {
+            "root_path": ("STRING", {"multiline": True, "default": "./"}),
+            "experiment": ("STRING", {"default": "experiment"}),
+            "video": ("STRING", {"default": "video"}),
+            "frames": ("STRING", {"default": "frames"}),
+            "smoothed": ("STRING", {"default": "smoothed"}),
+        }}
+
+    CATEGORY = "WarpFusion"
+    RETURN_TYPES = ("STRING", "STRING", "STRING")
+    RETURN_NAMES = ("video_path", "frames_path", "smoothed_frames_path")
+    FUNCTION = "build_paths"
+
+    def build_paths(self, root_path, experiment, video, frames, smoothed):
+        base_path = os.path.join(root_path, experiment)
+        video_path = os.path.join(base_path, video)
+        frames_path = os.path.join(base_path, frames)
+        smoothed_frames_path = os.path.join(base_path, smoothed)
+
+        return (video_path, frames_path, smoothed_frames_path)
+
 NODE_CLASS_MAPPINGS = {
     "LoadFrameSequence": LoadFrameSequence,
     "LoadFrame": LoadFrame,
@@ -409,7 +550,12 @@ def get_value(self, start, end, current_number):
     "SchedulerString":SchedulerString,
     "SchedulerFloat":SchedulerFloat,
     "SchedulerInt":SchedulerInt,
-    "FixedQueue":FixedQueue
+    "FixedQueue":FixedQueue,
+    "ApplyMask":ApplyMask,
+    "ApplyMaskConditional":ApplyMaskConditional,
+    "ApplyMaskLatent":ApplyMaskLatent,
+    "ApplyMaskLatentConditional":ApplyMaskLatentConditional,
+    "MakePaths": MakePaths,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -425,5 +571,10 @@ def get_value(self, start, end, current_number):
     "SchedulerString":"SchedulerString",
     "SchedulerFloat":"SchedulerFloat",
     "SchedulerInt":"SchedulerInt",
-    "FixedQueue":"FixedQueue"
+    "FixedQueue":"FixedQueue",
+    "ApplyMask":"ApplyMask",
+    "ApplyMaskConditional":"ApplyMaskConditional",
+    "ApplyMaskLatent":"ApplyMaskLatent",
+    "ApplyMaskLatentConditional":"ApplyMaskLatentConditional",
+    "MakePaths": "Make Paths",
 }
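
As a general ComfyUI convention (not something introduced by this diff), the two dicts above are how a custom node pack exposes its nodes: NODE_CLASS_MAPPINGS maps an internal node name to its class, while NODE_DISPLAY_NAME_MAPPINGS only changes the label shown in the UI (so "MakePaths" appears as "Make Paths"). A minimal, generic sketch of that pattern with a made-up node:

class MyNode:
    CATEGORY = "WarpFusion"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    def run(self, image):
        # Pass-through node, purely illustrative.
        return (image,)

NODE_CLASS_MAPPINGS = {"MyNode": MyNode}
NODE_DISPLAY_NAME_MAPPINGS = {"MyNode": "My Node"}
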

frame_utils.py

Lines changed: 12 additions & 4 deletions
@@ -45,7 +45,9 @@ def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):
 
 
 class FrameDataset():
-    def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on_getitem=False, start_frame=0, end_frame=-1, nth_frame=1):
+    def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on_getitem=False, start_frame=0, end_frame=-1, nth_frame=1, overwrite=False):
+        if outdir_prefix == '':
+            outdir_prefix = f'{start_frame}_{end_frame}_{nth_frame}'
         if end_frame == -1: end_frame = 999999999
         self.frame_paths = None
         image_extenstions = ['jpeg', 'jpg', 'png', 'tiff', 'bmp', 'webp']
@@ -66,10 +68,16 @@ def __init__(self, source_path, outdir_prefix='', videoframes_root='', update_on
             """if 1 video"""
             hash = generate_file_hash(source_path)[:10]
             out_path = os.path.join(videoframes_root, outdir_prefix+'_'+hash)
-
-            extractFrames(source_path, out_path,
+            files = glob.glob(os.path.join(out_path, '*.*'))
+            if len(files)>0 and not overwrite:
+                self.frame_paths = files
+                print(f'Found {len(self.frame_paths)} frames in {out_path}. Skipping extraction. Check overwrite option to overwrite.')
+                return
+            else:
+                print(f'Extracting frames from {source_path} to {out_path}')
+                extractFrames(source_path, out_path,
                           nth_frame=nth_frame, start_frame=start_frame, end_frame=end_frame)
-            self.frame_paths = glob.glob(os.path.join(out_path, '*.*')) #dont apply start-end here as already applied during video extraction
+            self.frame_paths = glob.glob(os.path.join(out_path, '*.*')) #dont apply start-end here as already applied during video extraction
             self.source_path = out_path
             if len(self.frame_paths)<1:
                 raise FileNotFoundError(f'Couldn`t extract frames from {source_path}\nPlease specify an existing source path.')
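
Usage note on the frame_utils.py change: extracted frames now land in a per-settings folder named from start_frame, end_frame, nth_frame and the file hash, and a repeat run with the same settings reuses those files instead of re-extracting unless overwrite is set. A hedged usage sketch; the import path and file locations below are illustrative (inside the pack the class comes from .frame_utils):

from frame_utils import FrameDataset

# First call extracts frames into <videoframes_root>/<start>_<end>_<nth>_<hash>/
ds = FrameDataset('input/video.mp4', videoframes_root='output',
                  start_frame=0, end_frame=-1, nth_frame=1)

# A second call with the same settings finds the cached frames and skips extraction;
# pass overwrite=True to force re-extraction into the same folder.
ds_again = FrameDataset('input/video.mp4', videoframes_root='output',
                        start_frame=0, end_frame=-1, nth_frame=1, overwrite=True)
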
