@@ -74,6 +74,25 @@ class VideoSegmentationSam3(desc.Node):
7474 description = "Weights file for the segmentation model." ,
7575 value = "${RDS_SAM3_MODEL_PATH}" ,
7676 ),
77+ desc .BoolParam (
78+ name = "combineFwdAndBwdSeg" ,
79+ label = "Combine Forward and Backward Segmentation" ,
80+ description = "Launch segmentation in both forward and backward directions and combine masks." ,
81+ value = False ,
82+ ),
83+ desc .BoolParam (
84+ name = "timeSlicing" ,
85+ label = "Time Slicing" ,
86+ description = "Enable time slicing by adding text prompt every N frames and by propagating forward on N frames." ,
87+ value = False ,
88+ ),
89+ desc .IntParam (
90+ name = "sliceSize" ,
91+ label = "Slice Size" ,
92+ description = "Number of frames on which the mask is propagated." ,
93+ value = 16 ,
94+ enabled = lambda node : node .timeSlicing .value ,
95+ ),
7796 desc .BoolParam (
7897 name = "maskInvert" ,
7998 label = "Invert Masks" ,
@@ -195,13 +214,16 @@ def prepare_masks_for_visualization(self, frame_to_output):
195214 frame_to_output [frame_idx ] = _processed_out
196215 return frame_to_output
197216
198- def propagate_in_video (self , predictor , session_id ):
217+ def propagate_in_video (self , predictor , session_id , start_frame_idx = None , max_frame_num_to_track = None , direction = "both" ):
199218 # propagate from start_frame_idx in the requested direction, for up to max_frame_num_to_track frames
200219 outputs_per_frame = {}
201220 for response in predictor .handle_stream_request (
202221 request = dict (
203222 type = "propagate_in_video" ,
204223 session_id = session_id ,
224+ propagation_direction = direction ,
225+ start_frame_idx = start_frame_idx ,
226+ max_frame_num_to_track = max_frame_num_to_track ,
205227 )
206228 ):
207229 outputs_per_frame [response ["frame_index" ]] = response ["outputs" ]
@@ -339,6 +361,25 @@ def processChunk(self, chunk):
339361
340362 colorPalette = image .paletteGenerator ()
341363 firstFrameId = chunk_image_paths [0 ][2 ]
364+ frameNumber = len (chunk_image_paths )
365+ frameIdxToTextPrompt_fwd = [0 ]
366+ frameIdxToTextPrompt_bwd = [frameNumber - 1 ]
367+ max_frame_num_to_track_fwd = None
368+ if chunk .node .timeSlicing .value :
369+ if chunk .node .sliceSize .value > 0 and chunk .node .sliceSize .value <= frameNumber :
370+ currFrameToTextPrompt_fwd = 0
371+ currFrameToTextPrompt_bwd = frameNumber - 1
372+ max_frame_num_to_track_fwd = chunk .node .sliceSize .value - 1
373+ max_frame_num_to_track_bwd = chunk .node .sliceSize .value
374+ while currFrameToTextPrompt_fwd + chunk .node .sliceSize .value < frameNumber :
375+ currFrameToTextPrompt_fwd += chunk .node .sliceSize .value
376+ frameIdxToTextPrompt_fwd .append (currFrameToTextPrompt_fwd )
377+ while currFrameToTextPrompt_bwd - chunk .node .sliceSize .value >= 0 :
378+ currFrameToTextPrompt_bwd -= chunk .node .sliceSize .value
379+ frameIdxToTextPrompt_bwd .append (currFrameToTextPrompt_bwd )
380+
381+ logger .debug (f"frameIdxToTextPromptFwd: { frameIdxToTextPrompt_fwd } " )
382+ logger .debug (f"frameIdxToTextPromptBwd: { frameIdxToTextPrompt_bwd } " )
342383
343384 for idx , path in enumerate (chunk_image_paths ):
344385 img , h_ori , w_ori , PAR , orientation = image .loadImage (str (chunk_image_paths [idx ][0 ]), True )
@@ -402,48 +443,76 @@ def processChunk(self, chunk):
402443 )
403444 session_id = response ["session_id" ]
404445
405- video_predictor .handle_request (
406- request = dict (
407- type = "add_prompt" ,
408- session_id = session_id ,
409- frame_index = 0 ,
410- text = chunk .node .prompt .value ,
411- )
412- )
413-
414- for f , bbox in bboxes .items ():
446+ # for f, bbox in bboxes.items():
447+ # video_predictor.handle_request(
448+ # request=dict(
449+ # type="add_prompt",
450+ # session_id=session_id,
451+ # frame_index=f,
452+ # bounding_boxes=bbox[0],
453+ # bounding_box_labels=bbox[1],
454+ # )
455+ # )
456+
457+ outputs_per_frame_fwd = {}
458+ for n , fIdx in enumerate (frameIdxToTextPrompt_fwd ):
415459 video_predictor .handle_request (
416460 request = dict (
417461 type = "add_prompt" ,
418462 session_id = session_id ,
419- frame_index = f ,
420- bounding_boxes = bbox [0 ],
421- bounding_box_labels = bbox [1 ],
463+ frame_index = fIdx ,
464+ text = chunk .node .prompt .value ,
422465 )
423466 )
467+ outputs_per_frame_curr_fwd = self .propagate_in_video (video_predictor , session_id , fIdx , max_frame_num_to_track_fwd , "forward" )
468+ #logger.debug(f"{outputs_per_frame_curr_fwd.keys()}")
469+ outputs_per_frame_fwd .update (outputs_per_frame_curr_fwd )
470+
471+ logger .debug (f"Fwd keys: { outputs_per_frame_fwd .keys ()} " )
424472
425- self . propagate_in_video ( video_predictor , session_id )
473+ video_predictor . handle_request ( request = dict ( type = "reset_session" , session_id = session_id ) )
426474
427- for f , objects in clicks .items ():
428- for obj_id , obj in objects .items ():
475+ outputs_per_frame_bwd = {}
476+ if chunk .node .combineFwdAndBwdSeg .value :
477+ for n , fIdx in enumerate (frameIdxToTextPrompt_bwd ):
429478 video_predictor .handle_request (
430479 request = dict (
431480 type = "add_prompt" ,
432481 session_id = session_id ,
433- frame_index = f ,
434- points = torch .tensor (np .array (obj [0 ])),
435- point_labels = torch .tensor (np .array (obj [1 ])),
436- obj_id = obj_id
482+ frame_index = fIdx ,
483+ text = chunk .node .prompt .value ,
437484 )
438485 )
486+ outputs_per_frame_curr_bwd = self .propagate_in_video (video_predictor , session_id , fIdx , max_frame_num_to_track_bwd , "backward" )
487+ #logger.debug(f"{outputs_per_frame_curr_bwd.keys()}")
488+ outputs_per_frame_bwd .update (outputs_per_frame_curr_bwd )
489+ logger .debug (f"Bwd keys: { outputs_per_frame_bwd .keys ()} " )
490+
491+ #outputs_per_frame = {}
492+
493+ # for f, objects in clicks.items():
494+ # for obj_id, obj in objects.items():
495+ # video_predictor.handle_request(
496+ # request=dict(
497+ # type="add_prompt",
498+ # session_id=session_id,
499+ # frame_index=f,
500+ # points=torch.tensor(np.array(obj[0])),
501+ # point_labels=torch.tensor(np.array(obj[1])),
502+ # obj_id=obj_id
503+ # )
504+ # )
505+
506+ #outputs_per_frame = self.propagate_in_video(video_predictor, session_id, f, None)
439507
440- outputs_per_frame = self .propagate_in_video (video_predictor , session_id )
508+ # outputs_per_frame = self.propagate_in_video(video_predictor, session_id)
441509
442- outputs_per_frame = self .prepare_masks_for_visualization (outputs_per_frame )
510+ outputs_per_frame_fwd = self .prepare_masks_for_visualization (outputs_per_frame_fwd )
511+ outputs_per_frame_bwd = self .prepare_masks_for_visualization (outputs_per_frame_bwd )
443512
444513 video_predictor .handle_request (request = dict (type = "close_session" , session_id = session_id ))
445514
446- for frameId , masks in outputs_per_frame .items ():
515+ for frameId , masks in outputs_per_frame_fwd .items ():
447516 maskImage = np .zeros_like (img )
448517 colorMaskImage = np .zeros_like (img )
449518 if chunk .node .outputCryptomatte .value :
@@ -465,6 +534,9 @@ def processChunk(self, chunk):
465534 manifest [obj_name ] = hex_val
466535 crypto_id [mask ] = f32_hash
467536 crypto_cov [mask ] = 1.0
537+ if frameId in outputs_per_frame_bwd .keys ():
538+ for key , mask in outputs_per_frame_bwd [frameId ].items ():
539+ maskImage [mask ] = [255 , 255 , 255 ]
468540
469541 if chunk .node .outputCryptomatte .value :
470542 spec = oiio .ImageSpec (img .shape [1 ], img .shape [0 ], 7 , oiio .FLOAT )