Skip to content

Commit 68a22a3

Browse files
authored
Merge pull request #41 from meshroomHub/dev/sam3VideoUpdatePropagate
Add VideoSegmentationSam3Text node
2 parents 9355410 + edcba97 commit 68a22a3

2 files changed

Lines changed: 614 additions & 24 deletions

File tree

meshroom/imageSegmentation/VideoSegmentationSam3.py

Lines changed: 96 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,25 @@ class VideoSegmentationSam3(desc.Node):
7474
description="Weights file for the segmentation model.",
7575
value="${RDS_SAM3_MODEL_PATH}",
7676
),
77+
desc.BoolParam(
78+
name="combineFwdAndBwdSeg",
79+
label="Combine Forward and Backward Segmentation",
80+
description="Launch segmentation in both forward and backward directions and combine masks.",
81+
value=False,
82+
),
83+
desc.BoolParam(
84+
name="timeSlicing",
85+
label="Time Slicing",
86+
description="Enable time slicing by adding text prompt every N frames and by propagating forward on N frames.",
87+
value=False,
88+
),
89+
desc.IntParam(
90+
name="sliceSize",
91+
label="Slice Size",
92+
description="Number of frames on which the mask is propagated.",
93+
value=16,
94+
enabled=lambda node: node.timeSlicing.value,
95+
),
7796
desc.BoolParam(
7897
name="maskInvert",
7998
label="Invert Masks",
@@ -195,13 +214,16 @@ def prepare_masks_for_visualization(self, frame_to_output):
195214
frame_to_output[frame_idx] = _processed_out
196215
return frame_to_output
197216

198-
def propagate_in_video(self, predictor, session_id):
217+
def propagate_in_video(self, predictor, session_id, start_frame_idx=None, max_frame_num_to_track=None, direction="both"):
199218
# we will just propagate from frame 0 to the end of the video
200219
outputs_per_frame = {}
201220
for response in predictor.handle_stream_request(
202221
request=dict(
203222
type="propagate_in_video",
204223
session_id=session_id,
224+
propagation_direction=direction,
225+
start_frame_idx=start_frame_idx,
226+
max_frame_num_to_track=max_frame_num_to_track,
205227
)
206228
):
207229
outputs_per_frame[response["frame_index"]] = response["outputs"]
@@ -339,6 +361,25 @@ def processChunk(self, chunk):
339361

340362
colorPalette = image.paletteGenerator()
341363
firstFrameId = chunk_image_paths[0][2]
364+
frameNumber = len(chunk_image_paths)
365+
frameIdxToTextPrompt_fwd = [0]
366+
frameIdxToTextPrompt_bwd = [frameNumber - 1]
367+
max_frame_num_to_track_fwd = None
368+
if chunk.node.timeSlicing.value:
369+
if chunk.node.sliceSize.value > 0 and chunk.node.sliceSize.value <= frameNumber:
370+
currFrameToTextPrompt_fwd = 0
371+
currFrameToTextPrompt_bwd = frameNumber - 1
372+
max_frame_num_to_track_fwd = chunk.node.sliceSize.value - 1
373+
max_frame_num_to_track_bwd = chunk.node.sliceSize.value
374+
while currFrameToTextPrompt_fwd + chunk.node.sliceSize.value < frameNumber:
375+
currFrameToTextPrompt_fwd += chunk.node.sliceSize.value
376+
frameIdxToTextPrompt_fwd.append(currFrameToTextPrompt_fwd)
377+
while currFrameToTextPrompt_bwd - chunk.node.sliceSize.value >= 0:
378+
currFrameToTextPrompt_bwd -= chunk.node.sliceSize.value
379+
frameIdxToTextPrompt_bwd.append(currFrameToTextPrompt_bwd)
380+
381+
logger.debug(f"frameIdxToTextPromptFwd: {frameIdxToTextPrompt_fwd}")
382+
logger.debug(f"frameIdxToTextPromptBwd: {frameIdxToTextPrompt_bwd}")
342383

343384
for idx, path in enumerate(chunk_image_paths):
344385
img, h_ori, w_ori, PAR, orientation = image.loadImage(str(chunk_image_paths[idx][0]), True)
@@ -402,48 +443,76 @@ def processChunk(self, chunk):
402443
)
403444
session_id = response["session_id"]
404445

405-
video_predictor.handle_request(
406-
request=dict(
407-
type="add_prompt",
408-
session_id=session_id,
409-
frame_index=0,
410-
text=chunk.node.prompt.value,
411-
)
412-
)
413-
414-
for f, bbox in bboxes.items():
446+
# for f, bbox in bboxes.items():
447+
# video_predictor.handle_request(
448+
# request=dict(
449+
# type="add_prompt",
450+
# session_id=session_id,
451+
# frame_index=f,
452+
# bounding_boxes=bbox[0],
453+
# bounding_box_labels=bbox[1],
454+
# )
455+
# )
456+
457+
outputs_per_frame_fwd = {}
458+
for n, fIdx in enumerate(frameIdxToTextPrompt_fwd):
415459
video_predictor.handle_request(
416460
request=dict(
417461
type="add_prompt",
418462
session_id=session_id,
419-
frame_index=f,
420-
bounding_boxes=bbox[0],
421-
bounding_box_labels=bbox[1],
463+
frame_index=fIdx,
464+
text=chunk.node.prompt.value,
422465
)
423466
)
467+
outputs_per_frame_curr_fwd = self.propagate_in_video(video_predictor, session_id, fIdx, max_frame_num_to_track_fwd, "forward")
468+
#logger.debug(f"{outputs_per_frame_curr_fwd.keys()}")
469+
outputs_per_frame_fwd.update(outputs_per_frame_curr_fwd)
470+
471+
logger.debug(f"Fwd keys: {outputs_per_frame_fwd.keys()}")
424472

425-
self.propagate_in_video(video_predictor, session_id)
473+
video_predictor.handle_request(request=dict(type="reset_session", session_id=session_id))
426474

427-
for f, objects in clicks.items():
428-
for obj_id, obj in objects.items():
475+
outputs_per_frame_bwd = {}
476+
if chunk.node.combineFwdAndBwdSeg.value:
477+
for n, fIdx in enumerate(frameIdxToTextPrompt_bwd):
429478
video_predictor.handle_request(
430479
request=dict(
431480
type="add_prompt",
432481
session_id=session_id,
433-
frame_index=f,
434-
points=torch.tensor(np.array(obj[0])),
435-
point_labels=torch.tensor(np.array(obj[1])),
436-
obj_id=obj_id
482+
frame_index=fIdx,
483+
text=chunk.node.prompt.value,
437484
)
438485
)
486+
outputs_per_frame_curr_bwd = self.propagate_in_video(video_predictor, session_id, fIdx, max_frame_num_to_track_bwd, "backward")
487+
#logger.debug(f"{outputs_per_frame_curr_bwd.keys()}")
488+
outputs_per_frame_bwd.update(outputs_per_frame_curr_bwd)
489+
logger.debug(f"Bwd keys: {outputs_per_frame_bwd.keys()}")
490+
491+
#outputs_per_frame = {}
492+
493+
# for f, objects in clicks.items():
494+
# for obj_id, obj in objects.items():
495+
# video_predictor.handle_request(
496+
# request=dict(
497+
# type="add_prompt",
498+
# session_id=session_id,
499+
# frame_index=f,
500+
# points=torch.tensor(np.array(obj[0])),
501+
# point_labels=torch.tensor(np.array(obj[1])),
502+
# obj_id=obj_id
503+
# )
504+
# )
505+
506+
#outputs_per_frame = self.propagate_in_video(video_predictor, session_id, f, None)
439507

440-
outputs_per_frame = self.propagate_in_video(video_predictor, session_id)
508+
#outputs_per_frame = self.propagate_in_video(video_predictor, session_id)
441509

442-
outputs_per_frame = self.prepare_masks_for_visualization(outputs_per_frame)
510+
outputs_per_frame_fwd = self.prepare_masks_for_visualization(outputs_per_frame_fwd)
511+
outputs_per_frame_bwd = self.prepare_masks_for_visualization(outputs_per_frame_bwd)
443512

444513
video_predictor.handle_request(request=dict(type="close_session", session_id=session_id))
445514

446-
for frameId, masks in outputs_per_frame.items():
515+
for frameId, masks in outputs_per_frame_fwd.items():
447516
maskImage = np.zeros_like(img)
448517
colorMaskImage = np.zeros_like(img)
449518
if chunk.node.outputCryptomatte.value:
@@ -465,6 +534,9 @@ def processChunk(self, chunk):
465534
manifest[obj_name] = hex_val
466535
crypto_id[mask] = f32_hash
467536
crypto_cov[mask] = 1.0
537+
if frameId in outputs_per_frame_bwd.keys():
538+
for key, mask in outputs_per_frame_bwd[frameId].items():
539+
maskImage[mask] = [255, 255, 255]
468540

469541
if chunk.node.outputCryptomatte.value:
470542
spec = oiio.ImageSpec(img.shape[1], img.shape[0], 7, oiio.FLOAT)

0 commit comments

Comments
 (0)