Skip to content

Commit 9355410

Browse files
authored
Merge pull request #40 from meshroomHub/bugfix/sam3VideoNodeSizeAndFrameId
Bugfix SAM3D and SDMatte
2 parents 07ecbaa + 42f5ede commit 9355410

2 files changed

Lines changed: 35 additions & 6 deletions

File tree

meshroom/imageSegmentation/SDMatte.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -340,15 +340,18 @@ def build_SDMatte_model(self, modelFolder, checkpoint, device, promptType):
340340
conv_scale=3,
341341
num_inference_steps=1,
342342
aux_input=promptType,
343+
aux_input_list=["point_mask", "bbox_mask", "mask", "trimap"],
344+
attn_mask_aux_input=["point_mask", "bbox_mask", "mask", "trimap"],
343345
add_noise=False,
344346
use_dis_loss=True,
345347
use_aux_input=True,
346348
use_coor_input=True,
347349
use_attention_mask=True,
350+
use_encoder_attention_mask=True,
348351
residual_connection=False,
349352
use_encoder_hidden_states=True,
350353
use_attention_mask_list=[True, True, True],
351-
use_encoder_hidden_states_list=[False, True, False],
354+
use_encoder_hidden_states_list=[True, True, True],
352355
)
353356
model.to(device)
354357
DetectionCheckpointer(model).load(checkpoint)
@@ -421,6 +424,7 @@ def processChunk(self, chunk):
421424
if promptType == "":
422425
raise ValueError("Some images have no valid prompt to drive the matting process !!!")
423426
else:
427+
logger.info(f"prompt type: {promptType}")
424428

425429
if not os.path.exists(chunk.node.output.value):
426430
os.mkdir(chunk.node.output.value)
@@ -465,9 +469,14 @@ def processChunk(self, chunk):
465469
mask = maskRGB[:,:,0]
466470
mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)
467471
mask_scaled = mask_sized.copy() * 2 - 1
468-
sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
469-
sample["mask_coords"] = np.array([0, 0, 1, 1])
470-
sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
472+
if promptType == "mask":
473+
sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
474+
sample["mask_coords"] = np.array([0, 0, 1, 1])
475+
sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
476+
else:
477+
sample["trimap"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
478+
sample["trimap_coords"] = np.array([0, 0, 1, 1])
479+
sample["trimap_coords"] = torch.from_numpy(sample["trimap_coords"]).float().unsqueeze(0)
471480
elif promptType == "auto_mask":
472481
mask = np.ones_like(img)[:,:,0]
473482
mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)

meshroom/imageSegmentation/VideoSegmentationSam3.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,28 @@
1212

1313
class Sam3VideoNodeSize(desc.MultiDynamicNodeSize):
1414
def computeSize(self, node):
15+
if node.attribute(self._params[0]).isLink:
16+
return node.attribute(self._params[0]).inputLink.node.size
17+
18+
from pathlib import Path
19+
20+
input_path_param = node.attribute(self._params[0])
21+
extension_param = node.attribute(self._params[1])
22+
input_path = input_path_param.value
23+
extension = extension_param.value
24+
include_suffixes = [extension.lower(), extension.upper()]
25+
1526
size = 1
27+
if Path(input_path).is_dir():
28+
import itertools
29+
image_paths = list(itertools.chain(*(Path(input_path).glob(f'*.{suffix}') for suffix in include_suffixes)))
30+
size = len(image_paths)
31+
1632
return size
1733

1834
class VideoSegmentationSam3(desc.Node):
1935
size = Sam3VideoNodeSize(['input', 'extensionIn'])
20-
gpu = desc.Level.INTENSIVE
36+
gpu = desc.Level.EXTREME
2137

2238
category = "Utils"
2339
documentation = """
@@ -322,14 +338,18 @@ def processChunk(self, chunk):
322338
bboxes = {}
323339

324340
colorPalette = image.paletteGenerator()
341+
firstFrameId = chunk_image_paths[0][2]
325342

326343
for idx, path in enumerate(chunk_image_paths):
327344
img, h_ori, w_ori, PAR, orientation = image.loadImage(str(chunk_image_paths[idx][0]), True)
328345
pil_images.append(Image.fromarray((255.0*img).astype("uint8")))
329346
sourceInfo = {"h_ori": h_ori, "w_ori": w_ori, "PAR": PAR, "orientation": orientation}
330347

331348
viewId = chunk_image_paths[idx][1]
332-
frameId = chunk_image_paths[idx][2]
349+
if firstFrameId is None or chunk_image_paths[idx][2] is None:
350+
frameId = idx
351+
else:
352+
frameId = chunk_image_paths[idx][2] - firstFrameId
333353

334354
objects = {}
335355
if viewId is not None and viewId in posClickDictFromShape:

0 commit comments

Comments (0)