66import folder_paths
77from .frame_utils import FrameDataset , StylizedFrameDataset , get_scheduled_arg , get_size , save_video
88
9+ class ApplyMask :
10+ @classmethod
11+ def INPUT_TYPES (s ):
12+ return {
13+ "required" : {
14+ "destination" : ("IMAGE" ,),
15+ "source" : ("IMAGE" ,),
16+ },
17+ "optional" : {
18+ "mask" : ("MASK" ,),
19+ }
20+ }
21+ RETURN_TYPES = ("IMAGE" ,)
22+ FUNCTION = "composite"
23+
24+ CATEGORY = "WarpFusion"
25+
26+ def composite (self , destination , source , mask = None ):
27+
28+ mask = mask [..., None ].repeat (1 ,1 ,1 ,destination .shape [- 1 ])
29+ res = destination * (1 - mask ) + source * (mask )
30+ return (res ,)
31+
32+ class ApplyMaskConditional :
33+ @classmethod
34+ def INPUT_TYPES (s ):
35+ return {
36+ "required" : {
37+ "destination" : ("IMAGE" ,),
38+ "source" : ("IMAGE" ,),
39+ "current_frame_number" : ("INT" ,),
40+ "apply_at_frames" : ("STRING" ,),
41+ "don_not_apply_at_frames" : ("BOOLEAN" ,),
42+ },
43+ "optional" : {
44+ "mask" : ("MASK" ,),
45+ }
46+ }
47+ RETURN_TYPES = ("IMAGE" ,)
48+ FUNCTION = "composite"
49+
50+ CATEGORY = "WarpFusion"
51+
52+ def composite (self , destination , source , current_frame_number , apply_at_frames , don_not_apply_at_frames , mask = None ):
53+ idx_list = [int (i ) for i in apply_at_frames .split (',' )]
54+ if (current_frame_number not in idx_list ) if don_not_apply_at_frames else (current_frame_number in idx_list ):
55+ # Convert mask to correct format for interpolation [b,c,h,w]
56+ mask = mask [None ,...]
57+
58+ # Resize mask to destination size using explicit dimensions
59+ mask = torch .nn .functional .interpolate (mask , size = (destination .shape [1 ], destination .shape [2 ]), mode = 'bilinear' )
60+
61+ # Convert back to [b,h,w,1] format
62+ mask = mask [0 ,...,None ].repeat (1 ,1 ,1 ,destination .shape [- 1 ])
63+
64+ source = source .permute (0 ,3 ,1 ,2 )
65+ source = torch .nn .functional .interpolate (source , size = (destination .shape [1 ], destination .shape [2 ]), mode = 'bilinear' )
66+ source = source .permute (0 ,2 ,3 ,1 )
67+
68+ res = destination * (1 - mask ) + source * (mask )
69+ return (res ,)
70+ else :
71+ return (destination ,)
72+
73+ class ApplyMaskLatent :
74+ @classmethod
75+ def INPUT_TYPES (s ):
76+ return {
77+ "required" : {
78+ "destination" : ("LATENT" ,),
79+ "source" : ("LATENT" ,),
80+ },
81+ "optional" : {
82+ "mask" : ("MASK" ,),
83+ }
84+ }
85+ RETURN_TYPES = ("LATENT" ,)
86+ FUNCTION = "composite"
87+
88+ CATEGORY = "WarpFusion"
89+
90+ def composite (self , destination , source , mask = None ):
91+ destination = destination ['samples' ]
92+ source = source ['samples' ]
93+ mask = mask [None , ...]
94+ mask = torch .nn .functional .interpolate (mask , size = (destination .shape [2 ], destination .shape [3 ]))
95+ res = destination * (1 - mask ) + source * (mask )
96+ return ({"samples" :res }, )
97+
98+ class ApplyMaskLatentConditional :
99+ @classmethod
100+ def INPUT_TYPES (s ):
101+ return {
102+ "required" : {
103+ "destination" : ("LATENT" ,),
104+ "source" : ("LATENT" ,),
105+ "current_frame_number" : ("INT" ,),
106+ "apply_at_frames" : ("STRING" ,),
107+ "don_not_apply_at_frames" : ("BOOLEAN" ,),
108+ },
109+ "optional" : {
110+ "mask" : ("MASK" ,),
111+ }
112+ }
113+ RETURN_TYPES = ("LATENT" ,)
114+ FUNCTION = "composite"
115+
116+ CATEGORY = "WarpFusion"
117+
118+ def composite (self , destination , source , current_frame_number , apply_at_frames , don_not_apply_at_frames , mask = None ):
119+ destination = destination ['samples' ]
120+ source = source ['samples' ]
121+ idx_list = [int (i ) for i in apply_at_frames .split (',' )]
122+ if (current_frame_number not in idx_list ) if don_not_apply_at_frames else (current_frame_number in idx_list ):
123+ mask = mask [None , ...]
124+ mask = torch .nn .functional .interpolate (mask , size = (destination .shape [2 ], destination .shape [3 ]))
125+ res = destination * (1 - mask ) + source * (mask )
126+ return ({"samples" :res }, )
127+ else :
128+ return ({"samples" :destination }, )
129+
9130class LoadFrameSequence :
10131 @classmethod
11132 def INPUT_TYPES (self ):
@@ -90,26 +211,21 @@ def INPUT_TYPES(self):
90211 "start_frame" :("INT" , {"default" : 0 , "min" : 0 , "max" : 9999999999 }),
91212 "end_frame" :("INT" , {"default" : - 1 , "min" : - 1 , "max" : 9999999999 }),
92213 "nth_frame" :("INT" , {"default" : 1 , "min" : 1 , "max" : 9999999999 }),
93- },
214+ "overwrite" :("BOOLEAN" , {"default" : False })
215+ }
94216 }
95217
96218 CATEGORY = "WarpFusion"
97219 RETURN_TYPES = ("FRAME_DATASET" , "INT" )
98220 RETURN_NAMES = ("FRAME_DATASET" , "Total_frames" )
99221 FUNCTION = "get_frames"
100222
101- def get_frames (self , file_path , update_on_frame_load , start_frame , end_frame , nth_frame ):
223+ def get_frames (self , file_path , update_on_frame_load , start_frame , end_frame , nth_frame , overwrite ):
102224 ds = FrameDataset (file_path , outdir_prefix = '' , videoframes_root = folder_paths .get_output_directory (),
103- update_on_getitem = update_on_frame_load , start_frame = start_frame , end_frame = end_frame , nth_frame = nth_frame )
225+ update_on_getitem = update_on_frame_load , start_frame = start_frame , end_frame = end_frame , nth_frame = nth_frame , overwrite = overwrite )
226+ if len (ds )== 0 :
227+ raise Exception (f"Found 0 frames in path { file_path } " ) #thanks to https://github.com/Aljnk
104228 return (ds ,len (ds ))
105-
106- @classmethod
107- def VALIDATE_INPUTS (self , file_path , update_on_frame_load , start_frame , end_frame , nth_frame ):
108- _ , n_frames = self .get_frames (self , file_path , update_on_frame_load , start_frame , end_frame , nth_frame )
109- if n_frames == 0 :
110- return f"Found 0 frames in path { file_path } "
111-
112- return True
113229
114230class LoadFrameFromFolder :
115231 @classmethod
@@ -308,6 +424,7 @@ def export_video(self, output_dir, frames_input_dir, batch_name, first_frame=1,
308424 print ('Exporting video.' )
309425 save_video (indir = frames_input_dir , video_out = output_dir , batch_name = batch_name , start_frame = first_frame ,
310426 last_frame = last_frame , fps = fps , output_format = output_format , use_deflicker = use_deflicker )
427+ # raise Exception(f'Exported video successfully. This exception is raised to just stop the endless cycle :D.\n you can find your video at {output_dir}')
311428 return ()
312429
313430class SchedulerInt :
@@ -396,6 +513,30 @@ def INPUT_TYPES(self):
396513 def get_value (self , start , end , current_number ):
397514 return (current_number , start , end )
398515
516+ class MakePaths :
517+ @classmethod
518+ def INPUT_TYPES (self ):
519+ return {"required" : {
520+ "root_path" : ("STRING" , {"multiline" : True , "default" : "./" }),
521+ "experiment" : ("STRING" , {"default" : "experiment" }),
522+ "video" : ("STRING" , {"default" : "video" }),
523+ "frames" : ("STRING" , {"default" : "frames" }),
524+ "smoothed" : ("STRING" , {"default" : "smoothed" }),
525+ }}
526+
527+ CATEGORY = "WarpFusion"
528+ RETURN_TYPES = ("STRING" , "STRING" , "STRING" )
529+ RETURN_NAMES = ("video_path" , "frames_path" , "smoothed_frames_path" )
530+ FUNCTION = "build_paths"
531+
532+ def build_paths (self , root_path , experiment , video , frames , smoothed ):
533+ base_path = os .path .join (root_path , experiment )
534+ video_path = os .path .join (base_path , video )
535+ frames_path = os .path .join (base_path , frames )
536+ smoothed_frames_path = os .path .join (base_path , smoothed )
537+
538+ return (video_path , frames_path , smoothed_frames_path )
539+
399540NODE_CLASS_MAPPINGS = {
400541 "LoadFrameSequence" : LoadFrameSequence ,
401542 "LoadFrame" : LoadFrame ,
@@ -409,7 +550,12 @@ def get_value(self, start, end, current_number):
409550 "SchedulerString" :SchedulerString ,
410551 "SchedulerFloat" :SchedulerFloat ,
411552 "SchedulerInt" :SchedulerInt ,
412- "FixedQueue" :FixedQueue
553+ "FixedQueue" :FixedQueue ,
554+ "ApplyMask" :ApplyMask ,
555+ "ApplyMaskConditional" :ApplyMaskConditional ,
556+ "ApplyMaskLatent" :ApplyMaskLatent ,
557+ "ApplyMaskLatentConditional" :ApplyMaskLatentConditional ,
558+ "MakePaths" : MakePaths ,
413559}
414560
415561NODE_DISPLAY_NAME_MAPPINGS = {
@@ -425,5 +571,10 @@ def get_value(self, start, end, current_number):
425571 "SchedulerString" :"SchedulerString" ,
426572 "SchedulerFloat" :"SchedulerFloat" ,
427573 "SchedulerInt" :"SchedulerInt" ,
428- "FixedQueue" :"FixedQueue"
574+ "FixedQueue" :"FixedQueue" ,
575+ "ApplyMask" :"ApplyMask" ,
576+ "ApplyMaskConditional" :"ApplyMaskConditional" ,
577+ "ApplyMaskLatent" :"ApplyMaskLatent" ,
578+ "ApplyMaskLatentConditional" :"ApplyMaskLatentConditional" ,
579+ "MakePaths" : "Make Paths" ,
429580}
0 commit comments