Skip to content

Commit d1f2c37

Browse files
committed
SDMatte trimap update
1 parent 81bb90b commit d1f2c37

1 file changed

Lines changed: 13 additions & 4 deletions

File tree

meshroom/imageSegmentation/SDMatte.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -340,15 +340,18 @@ def build_SDMatte_model(self, modelFolder, checkpoint, device, promptType):
340340
conv_scale=3,
341341
num_inference_steps=1,
342342
aux_input=promptType,
343+
aux_input_list=["point_mask", "bbox_mask", "mask", "trimap"],
344+
attn_mask_aux_input=["point_mask", "bbox_mask", "mask", "trimap"],
343345
add_noise=False,
344346
use_dis_loss=True,
345347
use_aux_input=True,
346348
use_coor_input=True,
347349
use_attention_mask=True,
350+
use_encoder_attention_mask=True,
348351
residual_connection=False,
349352
use_encoder_hidden_states=True,
350353
use_attention_mask_list=[True, True, True],
351-
use_encoder_hidden_states_list=[False, True, False],
354+
use_encoder_hidden_states_list=[True, True, True],
352355
)
353356
model.to(device)
354357
DetectionCheckpointer(model).load(checkpoint)
@@ -421,6 +424,7 @@ def processChunk(self, chunk):
421424
if promptType == "":
422425
raise ValueError("Some images have no valid prompt to drive the matting process !!!")
423426
else:
427+
logger.info(f"prompt type: {promptType}")
424428

425429
if not os.path.exists(chunk.node.output.value):
426430
os.mkdir(chunk.node.output.value)
@@ -465,9 +469,14 @@ def processChunk(self, chunk):
465469
mask = maskRGB[:,:,0]
466470
mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)
467471
mask_scaled = mask_sized.copy() * 2 - 1
468-
sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
469-
sample["mask_coords"] = np.array([0, 0, 1, 1])
470-
sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
472+
if promptType == "mask":
473+
sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
474+
sample["mask_coords"] = np.array([0, 0, 1, 1])
475+
sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
476+
else:
477+
sample["trimap"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
478+
sample["trimap_coords"] = np.array([0, 0, 1, 1])
479+
sample["trimap_coords"] = torch.from_numpy(sample["trimap_coords"]).float().unsqueeze(0)
471480
elif promptType == "auto_mask":
472481
mask = np.ones_like(img)[:,:,0]
473482
mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)

Commit comments (0)