1- __version__ = "1.0"
1+ __version__ = "2.0"
22
33import os
44from pathlib import Path
@@ -16,8 +16,42 @@ class VideoSegmentationSam3Boxes(desc.Node):
1616
1717 category = "Segmentation"
1818 documentation = """
19- Based on the Segment Anything video predictor model 3, the node generates binary masks from a set of
20- bounding boxes contained in a json file.
19+ ## Video Segmentation with SAM3 Bounding Boxes
20+
21+ This node generates binary segmentation masks for video sequences using the **Segment Anything Model 3 (SAM3)** video predictor.
22+
23+ ### Inputs
24+ Segmentation is driven by bounding boxes provided in a `bboxes.json` file, typically generated by the **VideoSegmentationSam3Text** node.
25+
26+ ### Multi-Resolution Support
27+ To improve segmentation quality on small objects, the node can combine source images at three resolutions:
28+ - **Native resolution** (required)
29+ - **Upscaled x2** (optional)
30+ - **Upscaled x4** (optional)
31+
32+ When tiling is disabled, the resolution used for each bounding box is selected automatically based on its size:
33+ - Box smaller than **252×252** pixels → x4 image (if available)
34+ - Box smaller than **504×504** pixels → x2 image (if available)
35+ - Otherwise → native resolution image
36+
37+ The `Round Crop Size` option (only available when tiling is disabled) snaps crop dimensions to **252, 504, or 1008** pixels, which can improve model accuracy for small bounding boxes.
38+
39+ ### Tiling Mode
40+ When **Enable Tiling** is active, large bounding boxes are subdivided into overlapping tiles before being passed to the model. This allows processing of high-resolution regions that would otherwise exceed the model's input capacity.
41+ Key parameters:
42+ - **Target Tile Size**: Target size (in pixels) for each tile.
43+ - **Minimal Overlap**: Minimum pixel overlap between adjacent tiles to avoid boundary artifacts.
44+
45+ > **Note:** Tiling and multi-resolution upscaling are mutually exclusive. When tiling is enabled, native resolution images are always used.
46+
47+ ### Computation Logic
48+ For each tracked object (identified by a text prompt and an object ID):
49+ 1. The bounding boxes are extracted from `bboxes.json` and grouped into temporal chunks.
50+ 2. Each chunk is optionally split into tiles.
51+ 3. Cropped image sequences are fed to the SAM3 video predictor.
52+ 4. The model propagates masks across all frames in the chunk.
53+ 5. Predicted masks are resized and composited back into full-resolution mask images.
54+ 6. Final masks are saved to disk, optionally inverted.
2155"""
2256
2357 inputs = [
@@ -51,6 +85,33 @@ class VideoSegmentationSam3Boxes(desc.Node):
5185 description = "Folder containing the bboxes.json file associated to the sfmData used as input." ,
5286 value = "" ,
5387 ),
88+ desc .BoolParam (
89+ name = "enableTiling" ,
90+ label = "Enable Tiling" ,
91+ description = "Enable tiling in big boxes." ,
92+ value = True ,
93+ ),
94+ desc .IntParam (
95+ name = "targetTileSize" ,
96+ label = "Target Tile Size" ,
97+ description = "Tile size." ,
98+ value = 504 ,
99+ enabled = lambda node : node .enableTiling .value ,
100+ ),
101+ desc .IntParam (
102+ name = "minimalOverlap" ,
103+ label = "Minimal Overlap" ,
104+ description = "minimal tile overlap." ,
105+ value = 16 ,
106+ enabled = lambda node : node .enableTiling .value ,
107+ ),
108+ desc .BoolParam (
109+ name = "roundCropSize" ,
110+ label = "Round Crop Size" ,
111+ description = "Round crop size to 252, 504 or 1008 for tube with smaller bounding boxes." ,
112+ value = True ,
113+ enabled = lambda node : not node .enableTiling .value ,
114+ ),
54115 desc .File (
55116 name = "segmentationModelPath" ,
56117 label = "Segmentation Model" ,
@@ -157,9 +218,11 @@ def processChunk(self, chunk):
157218 frame_w = chunk_image_paths [0 ][3 ]
158219 frame_h = chunk_image_paths [0 ][4 ]
159220 par = chunk_image_paths [0 ][5 ]
221+ firstFrameId = chunk_image_paths [0 ][2 ]
160222 x2_ok = os .path .exists (chunk .node .inputx2 .value )
161223 x4_ok = os .path .exists (chunk .node .inputx4 .value )
162- bboxes = bboxUtils .extract_tracking (json_path , frame_w , frame_h , x2_ok , x4_ok , par )
224+ roundCrop = chunk .node .roundCropSize .value
225+ bboxes = bboxUtils .extract_tracking (json_path , frame_w , frame_h , x2_ok , x4_ok , roundCrop , par )
163226
164227 logger .debug (f"bboxes.keys() = { bboxes .keys ()} " )
165228
@@ -178,65 +241,85 @@ def processChunk(self, chunk):
178241 logger .info (f"key = { key } ; text prompt = { textPrompt } ; obj_id = { obj_id } " )
179242
180243 for frame_chunk in frame_chunks :
181- logger .info (frame_chunk )
182- pil_images = []
183- firstFrameId = frame_chunk .start_frame
184- for frame_idx , box in sorted (frame_chunk .boxes .items ()):
185- x1 , y1 , x2 , y2 = bboxUtils .box_to_display (box , sourceInfo ["PAR" ])
186- box_w = x2 - x1
187- box_h = y2 - y1
188-
189- if box_w == 252 and box_h == 252 :
190- img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][7 ]), True )
191- imgBuf = oiio .ImageBuf (img )
192- imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (4 * x1 , 4 * x2 , 4 * y1 , 4 * y2 ))
193- elif box_w == 504 and box_h == 504 :
194- img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][6 ]), True )
195- imgBuf = oiio .ImageBuf (img )
196- imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (2 * x1 , 2 * x2 , 2 * y1 , 2 * y2 ))
197- else :
198- img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][0 ]), True )
199- imgBuf = oiio .ImageBuf (img )
200- imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (x1 , x2 , y1 , y2 ))
201-
202- img_crop = imgBuf .get_pixels (format = oiio .FLOAT )
203- pil_images .append (Image .fromarray ((255.0 * img_crop ).astype ("uint8" )))
204-
205- response = video_predictor .handle_request (
206- request = dict (
207- type = "start_session" ,
208- resource_path = pil_images ,
244+ logger .info (f"frame_chunk:\n{ frame_chunk } ")
245+ logger .debug (f"{ frame_chunk .boxes } " )
246+
247+ chunk_tiles = [frame_chunk ]
248+ if chunk .node .enableTiling .value :
249+ chunk_tiles = bboxUtils .tile_chunk (frame_chunk , chunk .node .targetTileSize .value ,
250+ chunk .node .minimalOverlap .value , sourceInfo ["PAR" ], logger )
251+ # In tiling mode, avoid loading all frames for every new tiles
252+ full_pil_images = {}
253+ if chunk .node .enableTiling .value :
254+ for frameId , _ in chunk_tiles [0 ].boxes .items ():
255+ img , h_ori , w_ori , PAR , orientation = image .loadImage (str (chunk_image_paths [frameId - firstFrameId ][0 ]), True )
256+ full_pil_images [frameId ] = img
257+
258+ logger .info (f"chunk_tiles:\n{ chunk_tiles } ")
259+
260+ for chunk_tile in chunk_tiles :
261+ logger .debug (f"{ chunk_tile .boxes } " )
262+
263+ pil_images = []
264+ for frame_idx , box in sorted (chunk_tile .boxes .items ()):
265+ x1 , y1 , x2 , y2 = bboxUtils .box_to_display (box , sourceInfo ["PAR" ])
266+ box_w = x2 - x1
267+ box_h = y2 - y1
268+
269+ if box_w <= 252 and box_h <= 252 and x4_ok and not chunk .node .enableTiling .value :
270+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][7 ]), True )
271+ imgBuf = oiio .ImageBuf (img )
272+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (4 * x1 , 4 * x2 , 4 * y1 , 4 * y2 ))
273+ elif box_w <= 504 and box_h <= 504 and x2_ok and not chunk .node .enableTiling .value :
274+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][6 ]), True )
275+ imgBuf = oiio .ImageBuf (img )
276+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (2 * x1 , 2 * x2 , 2 * y1 , 2 * y2 ))
277+ elif not chunk .node .enableTiling .value :
278+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][0 ]), True )
279+ imgBuf = oiio .ImageBuf (img )
280+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (x1 , x2 , y1 , y2 ))
281+ else :
282+ # use already loaded images
283+ imgBuf = oiio .ImageBuf (full_pil_images [frame_idx ])
284+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (x1 , x2 , y1 , y2 ))
285+
286+ img_crop = imgBuf .get_pixels (format = oiio .FLOAT )
287+ pil_images .append (Image .fromarray ((255.0 * img_crop ).astype ("uint8" )))
288+
289+ response = video_predictor .handle_request (
290+ request = dict (
291+ type = "start_session" ,
292+ resource_path = pil_images ,
293+ )
294+ )
295+ session_id = response ["session_id" ]
296+
297+ video_predictor .handle_request (
298+ request = dict (
299+ type = "add_prompt" ,
300+ session_id = session_id ,
301+ frame_index = 0 ,
302+ text = textPrompt ,
209303 )
210- )
211- session_id = response ["session_id" ]
212-
213- video_predictor .handle_request (
214- request = dict (
215- type = "add_prompt" ,
216- session_id = session_id ,
217- frame_index = 0 ,
218- text = textPrompt ,
219304 )
220- )
221- outputs_per_frame = sam3Utils .propagateInVideo (video_predictor , session_id ) #, fIdx, max_frame_num_to_track, track_dir)
222- outputs_per_frame_visu = sam3Utils .prepareMasksForVisualization (outputs_per_frame )
223-
224- for frame_idx , box in sorted (frame_chunk .boxes .items ()):
225- x1 , y1 , x2 , y2 = box
226- box_w = x2 - x1
227- box_h = y2 - y1
228- frameId = frame_idx - firstFrameId
229- for key , maskBoxProb in outputs_per_frame_visu [frameId ].items ():
230- mask = maskBoxProb ["mask" ]
231- buf_in = oiio .ImageBuf (mask .astype ('float32' ))
232- buf_out = oiio .ImageBufAlgo .resample (buf_in , roi = oiio .ROI (0 , box_w , 0 , box_h ))
233- mask = buf_out .get_pixels ().reshape (box_h , box_w , 1 )
234- tgt = full_mask_images [frame_idx ][y1 :y2 ,x1 :x2 , :]
235- bool_mask = mask .squeeze () > 0
236- tgt [bool_mask ] = [255 , 255 , 255 ]
237-
238- video_predictor .handle_request (request = dict (type = "close_session" , session_id = session_id ))
239-
305+ outputs_per_frame = sam3Utils .propagateInVideo (video_predictor , session_id )
306+ outputs_per_frame_visu = sam3Utils .prepareMasksForVisualization (outputs_per_frame )
307+
308+ for frame_idx , box in sorted (chunk_tile .boxes .items ()):
309+ x1 , y1 , x2 , y2 = box
310+ box_w = x2 - x1
311+ box_h = y2 - y1
312+ frameId = frame_idx - chunk_tile .start_frame
313+ for key , maskBoxProb in outputs_per_frame_visu [frameId ].items ():
314+ mask = maskBoxProb ["mask" ]
315+ buf_in = oiio .ImageBuf (mask .astype ('float32' ))
316+ buf_out = oiio .ImageBufAlgo .resample (buf_in , roi = oiio .ROI (0 , box_w , 0 , box_h ))
317+ mask = buf_out .get_pixels ().reshape (box_h , box_w , 1 )
318+ tgt = full_mask_images [frame_idx ][y1 :y2 ,x1 :x2 , :]
319+ bool_mask = mask .squeeze () > 0
320+ tgt [bool_mask ] = [255 , 255 , 255 ]
321+
322+ video_predictor .handle_request (request = dict (type = "close_session" , session_id = session_id ))
240323
241324 for frameId , image_path in enumerate (chunk_image_paths ):
242325 if chunk .node .maskInvert .value :
0 commit comments