@@ -340,15 +340,18 @@ def build_SDMatte_model(self, modelFolder, checkpoint, device, promptType):
         conv_scale=3,
         num_inference_steps=1,
         aux_input=promptType,
+        aux_input_list=["point_mask", "bbox_mask", "mask", "trimap"],
+        attn_mask_aux_input=["point_mask", "bbox_mask", "mask", "trimap"],
         add_noise=False,
         use_dis_loss=True,
         use_aux_input=True,
         use_coor_input=True,
         use_attention_mask=True,
+        use_encoder_attention_mask=True,
         residual_connection=False,
         use_encoder_hidden_states=True,
         use_attention_mask_list=[True, True, True],
-        use_encoder_hidden_states_list=[False, True, False],
+        use_encoder_hidden_states_list=[True, True, True],
     )
     model.to(device)
     DetectionCheckpointer(model).load(checkpoint)
@@ -421,6 +424,7 @@ def processChunk(self, chunk):
         if promptType == "":
             raise ValueError("Some images have no valid prompt to drive the matting process !!!")
         else:
+            logger.info(f"prompt type: {promptType}")

             if not os.path.exists(chunk.node.output.value):
                 os.mkdir(chunk.node.output.value)
@@ -465,9 +469,14 @@ def processChunk(self, chunk):
             mask = maskRGB[:,:,0]
             mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)
             mask_scaled = mask_sized.copy() * 2 - 1
-            sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
-            sample["mask_coords"] = np.array([0, 0, 1, 1])
-            sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
+            if promptType == "mask":
+                sample["mask"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
+                sample["mask_coords"] = np.array([0, 0, 1, 1])
+                sample["mask_coords"] = torch.from_numpy(sample["mask_coords"]).float().unsqueeze(0)
+            else:
+                sample["trimap"] = F.to_tensor(mask_scaled).float().unsqueeze(0)
+                sample["trimap_coords"] = np.array([0, 0, 1, 1])
+                sample["trimap_coords"] = torch.from_numpy(sample["trimap_coords"]).float().unsqueeze(0)
         elif promptType == "auto_mask":
             mask = np.ones_like(img)[:,:,0]
             mask_sized = cv2.resize(mask, inference_size, interpolation=cv2.INTER_NEAREST)