Skip to content

Commit 273e030

Browse files
committed
Add tiling mode in VideoSegmentationSam3Boxes
1 parent 8af9d20 commit 273e030

3 files changed

Lines changed: 279 additions & 68 deletions

File tree

meshroom/imageSegmentation/VideoSegmentationSam3Boxes.py

Lines changed: 144 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.0"
1+
__version__ = "2.0"
22

33
import os
44
from pathlib import Path
@@ -16,8 +16,42 @@ class VideoSegmentationSam3Boxes(desc.Node):
1616

1717
category = "Segmentation"
1818
documentation = """
19-
Based on the Segment Anything video predictor model 3, the node generates binary masks from a set of
20-
bounding boxes contained in a json file.
19+
## Video Segmentation with SAM3 Bounding Boxes
20+
21+
This node generates binary segmentation masks for video sequences using the **Segment Anything Model 3 (SAM3)** video predictor.
22+
23+
### Inputs
24+
Segmentation is driven by bounding boxes provided in a `bboxes.json` file, typically generated by the **VideoSegmentationSam3Text** node.
25+
26+
### Multi-Resolution Support
27+
To improve segmentation quality on small objects, the node can combine source images at three resolutions:
28+
- **Native resolution** (required)
29+
- **Upscaled x2** (optional)
30+
- **Upscaled x4** (optional)
31+
32+
When tiling is disabled, the resolution used for each bounding box is selected automatically based on its size:
33+
- Box smaller than **252×252** pixels → x4 image (if available)
34+
- Box smaller than **504×504** pixels → x2 image (if available)
35+
- Otherwise → native resolution image
36+
37+
The `Round Crop Size` option (only available when tiling is disabled) snaps crop dimensions to **252, 504, or 1008** pixels, which can improve model accuracy for small bounding boxes.
38+
39+
### Tiling Mode
40+
When **Enable Tiling** is active, large bounding boxes are subdivided into overlapping tiles before being passed to the model. This allows processing of high-resolution regions that would otherwise exceed the model's input capacity.
41+
Key parameters:
42+
- **Target Tile Size**: Target size (in pixels) for each tile.
43+
- **Minimal Overlap**: Minimum pixel overlap between adjacent tiles to avoid boundary artifacts.
44+
45+
> **Note:** Tiling and multi-resolution upscaling are mutually exclusive. When tiling is enabled, native resolution images are always used.
46+
47+
### Computation Logic
48+
For each tracked object (identified by a text prompt and an object ID):
49+
1. The bounding boxes are extracted from `bboxes.json` and grouped into temporal chunks.
50+
2. Each chunk is optionally split into tiles.
51+
3. Cropped image sequences are fed to the SAM3 video predictor.
52+
4. The model propagates masks across all frames in the chunk.
53+
5. Predicted masks are resized and composited back into full-resolution mask images.
54+
6. Final masks are saved to disk, optionally inverted.
2155
"""
2256

2357
inputs = [
@@ -51,6 +85,33 @@ class VideoSegmentationSam3Boxes(desc.Node):
5185
description="Folder containing the bboxes.json file associated to the sfmData used as input.",
5286
value="",
5387
),
88+
desc.BoolParam(
89+
name="enableTiling",
90+
label="Enable Tiling",
91+
description="Enable tiling in big boxes.",
92+
value=True,
93+
),
94+
desc.IntParam(
95+
name="targetTileSize",
96+
label="Target Tile Size",
97+
description="Tile size.",
98+
value=504,
99+
enabled=lambda node: node.enableTiling.value,
100+
),
101+
desc.IntParam(
102+
name="minimalOverlap",
103+
label="Minimal Overlap",
104+
description="minimal tile overlap.",
105+
value=16,
106+
enabled=lambda node: node.enableTiling.value,
107+
),
108+
desc.BoolParam(
109+
name="roundCropSize",
110+
label="Round Crop Size",
111+
description="Round crop size to 252, 504 or 1008 for tube with smaller bounding boxes.",
112+
value=True,
113+
enabled=lambda node: not node.enableTiling.value,
114+
),
54115
desc.File(
55116
name="segmentationModelPath",
56117
label="Segmentation Model",
@@ -157,9 +218,11 @@ def processChunk(self, chunk):
157218
frame_w = chunk_image_paths[0][3]
158219
frame_h = chunk_image_paths[0][4]
159220
par = chunk_image_paths[0][5]
221+
firstFrameId = chunk_image_paths[0][2]
160222
x2_ok = os.path.exists(chunk.node.inputx2.value)
161223
x4_ok = os.path.exists(chunk.node.inputx4.value)
162-
bboxes = bboxUtils.extract_tracking(json_path, frame_w, frame_h, x2_ok, x4_ok, par)
224+
roundCrop = chunk.node.roundCropSize.value
225+
bboxes = bboxUtils.extract_tracking(json_path, frame_w, frame_h, x2_ok, x4_ok, roundCrop, par)
163226

164227
logger.debug(f"bboxes.keys() = {bboxes.keys()}")
165228

@@ -178,65 +241,85 @@ def processChunk(self, chunk):
178241
logger.info(f"key = {key} ; text prompt = {textPrompt} ; obj_id = {obj_id}")
179242

180243
for frame_chunk in frame_chunks:
181-
logger.info(frame_chunk)
182-
pil_images = []
183-
firstFrameId = frame_chunk.start_frame
184-
for frame_idx, box in sorted(frame_chunk.boxes.items()):
185-
x1, y1, x2, y2 = bboxUtils.box_to_display(box, sourceInfo["PAR"])
186-
box_w = x2 - x1
187-
box_h = y2 - y1
188-
189-
if box_w == 252 and box_h == 252:
190-
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][7]), True)
191-
imgBuf = oiio.ImageBuf(img)
192-
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(4*x1, 4*x2, 4*y1, 4*y2))
193-
elif box_w == 504 and box_h == 504:
194-
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][6]), True)
195-
imgBuf = oiio.ImageBuf(img)
196-
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(2*x1, 2*x2, 2*y1, 2*y2))
197-
else:
198-
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][0]), True)
199-
imgBuf = oiio.ImageBuf(img)
200-
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(x1, x2, y1, y2))
201-
202-
img_crop = imgBuf.get_pixels(format=oiio.FLOAT)
203-
pil_images.append(Image.fromarray((255.0*img_crop).astype("uint8")))
204-
205-
response = video_predictor.handle_request(
206-
request=dict(
207-
type="start_session",
208-
resource_path=pil_images,
244+
logger.info(f"frame_chunk: {frame_chunk}")
245+
logger.debug(f"{frame_chunk.boxes}")
246+
247+
chunk_tiles = [frame_chunk]
248+
if chunk.node.enableTiling.value:
249+
chunk_tiles = bboxUtils.tile_chunk(frame_chunk, chunk.node.targetTileSize.value,
250+
chunk.node.minimalOverlap.value, sourceInfo["PAR"], logger)
251+
# In tiling mode, avoid loading all frames for every new tiles
252+
full_pil_images = {}
253+
if chunk.node.enableTiling.value:
254+
for frameId, _ in chunk_tiles[0].boxes.items():
255+
img, h_ori, w_ori, PAR, orientation = image.loadImage(str(chunk_image_paths[frameId - firstFrameId][0]), True)
256+
full_pil_images[frameId] = img
257+
258+
logger.info(f"chunk_tiles: {chunk_tiles}")
259+
260+
for chunk_tile in chunk_tiles:
261+
logger.debug(f"{chunk_tile.boxes}")
262+
263+
pil_images = []
264+
for frame_idx, box in sorted(chunk_tile.boxes.items()):
265+
x1, y1, x2, y2 = bboxUtils.box_to_display(box, sourceInfo["PAR"])
266+
box_w = x2 - x1
267+
box_h = y2 - y1
268+
269+
if box_w <= 252 and box_h <= 252 and x4_ok and not chunk.node.enableTiling.value:
270+
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][7]), True)
271+
imgBuf = oiio.ImageBuf(img)
272+
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(4*x1, 4*x2, 4*y1, 4*y2))
273+
elif box_w <= 504 and box_h <= 504 and x2_ok and not chunk.node.enableTiling.value:
274+
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][6]), True)
275+
imgBuf = oiio.ImageBuf(img)
276+
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(2*x1, 2*x2, 2*y1, 2*y2))
277+
elif not chunk.node.enableTiling.value:
278+
img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][0]), True)
279+
imgBuf = oiio.ImageBuf(img)
280+
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(x1, x2, y1, y2))
281+
else:
282+
# use already loaded images
283+
imgBuf = oiio.ImageBuf(full_pil_images[frame_idx])
284+
imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(x1, x2, y1, y2))
285+
286+
img_crop = imgBuf.get_pixels(format=oiio.FLOAT)
287+
pil_images.append(Image.fromarray((255.0*img_crop).astype("uint8")))
288+
289+
response = video_predictor.handle_request(
290+
request=dict(
291+
type="start_session",
292+
resource_path=pil_images,
293+
)
294+
)
295+
session_id = response["session_id"]
296+
297+
video_predictor.handle_request(
298+
request=dict(
299+
type="add_prompt",
300+
session_id=session_id,
301+
frame_index=0,
302+
text=textPrompt,
209303
)
210-
)
211-
session_id = response["session_id"]
212-
213-
video_predictor.handle_request(
214-
request=dict(
215-
type="add_prompt",
216-
session_id=session_id,
217-
frame_index=0,
218-
text=textPrompt,
219304
)
220-
)
221-
outputs_per_frame = sam3Utils.propagateInVideo(video_predictor, session_id) #, fIdx, max_frame_num_to_track, track_dir)
222-
outputs_per_frame_visu = sam3Utils.prepareMasksForVisualization(outputs_per_frame)
223-
224-
for frame_idx, box in sorted(frame_chunk.boxes.items()):
225-
x1, y1, x2, y2 = box
226-
box_w = x2 - x1
227-
box_h = y2 - y1
228-
frameId = frame_idx - firstFrameId
229-
for key, maskBoxProb in outputs_per_frame_visu[frameId].items():
230-
mask = maskBoxProb["mask"]
231-
buf_in = oiio.ImageBuf(mask.astype('float32'))
232-
buf_out = oiio.ImageBufAlgo.resample(buf_in, roi=oiio.ROI(0, box_w, 0, box_h))
233-
mask = buf_out.get_pixels().reshape(box_h, box_w, 1)
234-
tgt = full_mask_images[frame_idx][y1:y2 ,x1:x2, :]
235-
bool_mask = mask.squeeze() > 0
236-
tgt[bool_mask] = [255, 255, 255]
237-
238-
video_predictor.handle_request(request=dict(type="close_session", session_id=session_id))
239-
305+
outputs_per_frame = sam3Utils.propagateInVideo(video_predictor, session_id)
306+
outputs_per_frame_visu = sam3Utils.prepareMasksForVisualization(outputs_per_frame)
307+
308+
for frame_idx, box in sorted(chunk_tile.boxes.items()):
309+
x1, y1, x2, y2 = box
310+
box_w = x2 - x1
311+
box_h = y2 - y1
312+
frameId = frame_idx - chunk_tile.start_frame
313+
for key, maskBoxProb in outputs_per_frame_visu[frameId].items():
314+
mask = maskBoxProb["mask"]
315+
buf_in = oiio.ImageBuf(mask.astype('float32'))
316+
buf_out = oiio.ImageBufAlgo.resample(buf_in, roi=oiio.ROI(0, box_w, 0, box_h))
317+
mask = buf_out.get_pixels().reshape(box_h, box_w, 1)
318+
tgt = full_mask_images[frame_idx][y1:y2 ,x1:x2, :]
319+
bool_mask = mask.squeeze() > 0
320+
tgt[bool_mask] = [255, 255, 255]
321+
322+
video_predictor.handle_request(request=dict(type="close_session", session_id=session_id))
240323

241324
for frameId, image_path in enumerate(chunk_image_paths):
242325
if chunk.node.maskInvert.value:

meshroom/rotoPersons.mg

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
{
2+
"header": {
3+
"releaseVersion": "2026.1.0+develop",
4+
"fileVersion": "2.0",
5+
"nodesVersions": {
6+
"CameraInit": "12.1",
7+
"CopyFiles": "1.3",
8+
"VideoSegmentationSam3Boxes": "2.0",
9+
"VideoSegmentationSam3Text": "1.0"
10+
},
11+
"template": true
12+
},
13+
"graph": {
14+
"CameraInit_1": {
15+
"nodeType": "CameraInit",
16+
"position": [
17+
-452,
18+
94
19+
],
20+
"inputs": {}
21+
},
22+
"CopyFiles_1": {
23+
"nodeType": "CopyFiles",
24+
"position": [
25+
229,
26+
73
27+
],
28+
"inputs": {
29+
"output": "{VideoSegmentationSam3Boxes_1.output}"
30+
}
31+
},
32+
"VideoSegmentationSam3Boxes_1": {
33+
"nodeType": "VideoSegmentationSam3Boxes",
34+
"position": [
35+
9,
36+
41
37+
],
38+
"inputs": {
39+
"input": "{VideoSegmentationSam3Text_1.input}",
40+
"masksFolder": "{VideoSegmentationSam3Text_1.output}",
41+
"bboxesFolder": "{VideoSegmentationSam3Text_1.output}",
42+
"verboseLevel": "debug"
43+
}
44+
},
45+
"VideoSegmentationSam3Text_1": {
46+
"nodeType": "VideoSegmentationSam3Text",
47+
"position": [
48+
-221,
49+
61
50+
],
51+
"inputs": {
52+
"input": "{CameraInit_1.output}",
53+
"timeSlicing": true,
54+
"sliceSize": 64,
55+
"verboseLevel": "debug"
56+
}
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)