1717from ..utility .utility import tensor2pil , pil2tensor
1818
1919script_directory = os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
20+ device = model_management .get_torch_device ()
21+ offload_device = model_management .unet_offload_device ()
2022
2123class BatchCLIPSeg :
2224
@@ -997,6 +999,7 @@ def INPUT_TYPES(cls):
997999- fill_holes: fill holes in the mask (slow)"""
9981000
9991001 def expand_mask (self , mask , expand , tapered_corners , flip_input , blur_radius , incremental_expandrate , lerp_alpha , decay_factor , fill_holes = False ):
1002+ import kornia .morphology as morph
10001003 alpha = lerp_alpha
10011004 decay = decay_factor
10021005 if flip_input :
@@ -1010,30 +1013,45 @@ def expand_mask(self, mask, expand, tapered_corners, flip_input, blur_radius, in
10101013 previous_output = None
10111014 current_expand = expand
10121015 for m in growmask :
1013- output = m .numpy ().astype (np .float32 )
1014- for _ in range (abs (round (current_expand ))):
1015- if current_expand < 0 :
1016- output = scipy .ndimage .grey_erosion (output , footprint = kernel )
1016+ output = m .unsqueeze (0 ).unsqueeze (0 ).to (device ) # Add batch and channel dims for kornia
1017+ if abs (round (current_expand )) > 0 :
1018+ # Create kernel - kornia expects kernel on same device as input
1019+ if tapered_corners :
1020+ kernel = torch .tensor ([[0 , 1 , 0 ],
1021+ [1 , 1 , 1 ],
1022+ [0 , 1 , 0 ]], dtype = torch .float32 , device = output .device )
10171023 else :
1018- output = scipy .ndimage .grey_dilation (output , footprint = kernel )
1024+ kernel = torch .tensor ([[1 , 1 , 1 ],
1025+ [1 , 1 , 1 ],
1026+ [1 , 1 , 1 ]], dtype = torch .float32 , device = output .device )
1027+
1028+ for _ in range (abs (round (current_expand ))):
1029+ if current_expand < 0 :
1030+ output = morph .erosion (output , kernel )
1031+ else :
1032+ output = morph .dilation (output , kernel )
1033+
1034+ output = output .squeeze (0 ).squeeze (0 ) # Remove batch and channel dims
1035+
10191036 if current_expand < 0 :
10201037 current_expand -= abs (incremental_expandrate )
10211038 else :
10221039 current_expand += abs (incremental_expandrate )
1040+
10231041 if fill_holes :
1042+ # For fill_holes, you might need to keep using scipy or implement GPU version
10241043 binary_mask = output > 0
1025- output = scipy .ndimage .binary_fill_holes (binary_mask )
1026- output = output .astype (np .float32 ) * 255
1027- output = torch .from_numpy (output )
1044+ output_np = binary_mask .cpu ().numpy ()
1045+ filled = scipy .ndimage .binary_fill_holes (output_np )
1046+ output = torch .from_numpy (filled .astype (np .float32 )).to (output .device )
1047+
10281048 if alpha < 1.0 and previous_output is not None :
1029- # Interpolate between the previous and current frame
10301049 output = alpha * output + (1 - alpha ) * previous_output
10311050 if decay < 1.0 and previous_output is not None :
1032- # Add the decayed previous output to the current frame
10331051 output += decay * previous_output
10341052 output = output / output .max ()
10351053 previous_output = output
1036- out .append (output )
1054+ out .append (output . cpu () )
10371055
10381056 if blur_radius != 0 :
10391057 # Convert the tensor list to PIL images, apply blur, and convert back
0 commit comments