88
99class Sam3VideoNodeSize (desc .MultiDynamicNodeSize ):
1010 def computeSize (self , node ):
11- # if node.attribute(self._params[0]).isLink:
12- # return node.attribute(self._params[0]).inputLink.node.size
13-
14- # from pathlib import Path
15-
16- # input_path_param = node.attribute(self._params[0])
17- # extension_param = node.attribute(self._params[1])
18- # input_path = input_path_param.value
19- # extension = extension_param.value
20- # include_suffixes = [extension.lower(), extension.upper()]
21-
2211 size = 1
23- # if Path(input_path).is_dir():
24- # import itertools
25- # image_paths = list(itertools.chain(*(Path(input_path).glob(f'*.{suffix}') for suffix in include_suffixes)))
26- # size = len(image_paths)
27-
2812 return size
2913
3014class VideoSegmentationSam3 (desc .Node ):
3115 size = Sam3VideoNodeSize (['input' , 'extensionIn' ])
3216 gpu = desc .Level .INTENSIVE
33- #parallelization = desc.Parallelization(blockSize=50)
3417
3518 category = "Utils"
3619 documentation = """
37- Based on the Segment Anything model 3, the node generates a binary mask from a text prompt.
38- It is strongly advised to launch a first segmentation using only a text prompt .
20+ Based on the Segment Anything model 3, the node generates a binary mask from a text prompt, a single bounding box or
21+ a set of positive and negative clicks (Clicks In/Out) .
3922Two masks are generated, a binary one and a colored one that the indexes of every sub masks.
4023Object Ids are color encoded as follow:
4124 0:[1,0,0] = xff0000
@@ -160,29 +143,12 @@ class VideoSegmentationSam3(desc.Node):
160143 keyType = "viewId" ,
161144 ),
162145 ),
163- desc .ShapeList (
164- name = "positiveBoxes" ,
165- label = "Positive Boxes" ,
166- description = "Prompt: Positive Bounding Boxes" ,
167- shape = desc .Rectangle (
168- name = "bbox" ,
169- label = "Bounding Box" ,
170- description = "Rectangle." ,
171- keyable = True ,
172- keyType = "viewId" ,
173- ),
174- ),
175- desc .ShapeList (
176- name = "negativeBoxes" ,
177- label = "Negative Boxes" ,
178- description = "Prompt: Negative Bounding Boxes" ,
179- shape = desc .Rectangle (
180- name = "bbox" ,
181- label = "Bounding Box" ,
182- description = "Rectangle." ,
183- keyable = True ,
184- keyType = "viewId" ,
185- ),
146+ desc .Rectangle (
147+ name = "boxPrompt" ,
148+ label = "Box Prompt" ,
149+ description = "Single bounding box used as initial prompt." ,
150+ keyable = True ,
151+ keyType = "viewId"
186152 ),
187153 ]
188154
@@ -259,19 +225,18 @@ def normalize_click(self, click_xy, img_w, img_h, PAR, orientation):
259225
260226 def getBboxDictWithViewIdAsKeyFromShape (self , shape ):
261227 bboxDictFromShape = {}
262- shapesBBoxesIn = shape .getShapesAsDict ()
263- if shapesBBoxesIn :
264- for sh in shapesBBoxesIn :
265- for key in sh ["observations" ]:
266- xc = sh ["observations" ][key ]["center" ]["x" ]
267- yc = sh ["observations" ][key ]["center" ]["y" ]
268- w = sh ["observations" ][key ]["size" ]["width" ]
269- h = sh ["observations" ][key ]["size" ]["height" ]
270- bb = [xc - w / 2 , yc - h / 2 , w , h ]
271- if key in bboxDictFromShape :
272- bboxDictFromShape [key ].append (bb )
273- else :
274- bboxDictFromShape [key ] = [bb ]
228+ sh = shape .getShapeAsDict ()
229+ if sh :
230+ for key in sh ["observations" ]:
231+ xc = sh ["observations" ][key ]["center" ]["x" ]
232+ yc = sh ["observations" ][key ]["center" ]["y" ]
233+ w = sh ["observations" ][key ]["size" ]["width" ]
234+ h = sh ["observations" ][key ]["size" ]["height" ]
235+ bb = [xc - w / 2 , yc - h / 2 , w , h ]
236+ if key in bboxDictFromShape :
237+ bboxDictFromShape [key ].append (bb )
238+ else :
239+ bboxDictFromShape [key ] = [bb ]
275240 return bboxDictFromShape
276241
277242 def normalize_bbox (self , bbox_xywh , img_w , img_h , PAR , orientation ):
@@ -345,8 +310,7 @@ def processChunk(self, chunk):
345310
346311 posClickDictFromShape = self .getClickDictWithViewIdAsKeyFromShape (chunk .node .positiveClicks )
347312 negClickDictFromShape = self .getClickDictWithViewIdAsKeyFromShape (chunk .node .negativeClicks )
348- posBboxDictFromShape = self .getBboxDictWithViewIdAsKeyFromShape (chunk .node .positiveBoxes )
349- negBboxDictFromShape = self .getBboxDictWithViewIdAsKeyFromShape (chunk .node .negativeBoxes )
313+ posBboxDictFromShape = self .getBboxDictWithViewIdAsKeyFromShape (chunk .node .boxPrompt )
350314
351315 metadata_deep_model = {}
352316 metadata_deep_model ["Meshroom:mrSegmentation:DeepModelName" ] = "SegmentAnything"
@@ -408,14 +372,6 @@ def processChunk(self, chunk):
408372 bboxes [frameId ][0 ].append (bbox )
409373 bboxes [frameId ][1 ].append (1 )
410374
411- if viewId is not None and str (viewId ) in negBboxDictFromShape :
412- if frameId not in bboxes :
413- bboxes [frameId ] = ([],[])
414- for bbox in negBboxDictFromShape [viewId ]:
415- bbox = self .normalize_bbox (bbox , img .shape [1 ], img .shape [0 ], PAR , orientation )
416- bboxes [frameId ][0 ].append (bbox )
417- bboxes [frameId ][1 ].append (0 )
418-
419375 chunk .logger .debug (f"clicks = { clicks } " )
420376 chunk .logger .debug (f"bboxes = { bboxes } " )
421377
@@ -427,7 +383,6 @@ def processChunk(self, chunk):
427383 )
428384 session_id = response ["session_id" ]
429385
430- #if chunk.node.prompt.value != "":
431386 response = video_predictor .handle_request (
432387 request = dict (
433388 type = "add_prompt" ,
0 commit comments