Skip to content

Commit db93015

Browse files
committed
Add VideoSegmentationSam3Boxes node
1 parent eeb972f commit db93015

3 files changed

Lines changed: 563 additions & 2 deletions

File tree

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
__version__ = "1.0"


# NOTE(review): `total_ordering` does not appear to be used anywhere in this
# file — confirm before removing (kept here because only part of the project
# may be visible).
from functools import total_ordering
import os
from pathlib import Path

from meshroom.core import desc
from meshroom.core.utils import VERBOSE_LEVEL
from pyalicevision import parallelization as avpar

import logging
# Module-level logger named after the node so its messages can be filtered
# in the Meshroom logs.
logger = logging.getLogger("VideoSegmentationSam3Boxes")
class VideoSegmentationSam3Boxes(desc.Node):
    """Meshroom node: generate binary segmentation masks for a video sequence.

    Uses the Segment Anything 3 (SAM3) video predictor, prompted per object by
    bounding boxes read from a ``bboxes.json`` file. Boxes of sentinel sizes are
    cropped from x2/x4 upscaled source images to improve mask resolution, the
    predicted masks are pasted back into full-resolution accumulation buffers,
    and one binary mask image is written per view.
    """

    # One dynamic chunk per set of views of the input SfMData.
    size = avpar.DynamicViewsSize("input")
    gpu = desc.Level.EXTREME

    category = "Segmentation"
    documentation = """
Based on the Segment Anything video predictor model 3, the node generates binary masks from a set of
bounding boxes contained in a json file.
"""

    inputs = [
        desc.File(
            name="input",
            label="Input",
            description="SfMData file.",
            value="",
        ),
        desc.File(
            name="inputx2",
            label="Inputx2",
            description="Folder containing source images upscale by 2.",
            value="",
        ),
        desc.File(
            name="inputx4",
            label="Inputx4",
            description="Folder containing source images upscale by 4.",
            value="",
        ),
        desc.File(
            name="masksFolder",
            label="Masks Folder",
            description="Folder containing the masks computed at original resolution.",
            value="",
        ),
        desc.File(
            name="bboxesFolder",
            label="Bounding Boxes Folder",
            description="Folder containing the bboxes.json file associated to the sfmData used as input.",
            value="",
        ),
        desc.File(
            name="segmentationModelPath",
            label="Segmentation Model",
            description="Weights file for the segmentation model.",
            value="${RDS_SAM3_MODEL_PATH}",
        ),
        desc.BoolParam(
            name="maskInvert",
            label="Invert Masks",
            description="Invert mask values. If selected, the pixels corresponding to the mask will be set to 0 instead of 255.",
            value=False,
        ),
        desc.BoolParam(
            name="useGpu",
            label="Use GPU",
            description="Use GPU for computation if available.",
            value=True,
            invalidate=False,
        ),
        desc.BoolParam(
            name="keepFilename",
            label="Keep Filename",
            description="Keep the filename of the inputs for the outputs.",
            value=True,
        ),
        desc.ChoiceParam(
            name="extensionOut",
            label="Output File Extension",
            description="Output image file extension.\n"
                        "If unset, the output file extension will match the input's if possible.",
            value="exr",
            values=["exr", "png", "jpg"],
            exclusive=True,
        ),
        desc.ChoiceParam(
            name="verboseLevel",
            label="Verbose Level",
            description="Verbosity level (fatal, error, warning, info, debug).",
            value="info",
            values=VERBOSE_LEVEL,
            exclusive=True,
        ),
    ]

    outputs = [
        desc.File(
            name="output",
            label="Masks Folder",
            description="Output path for the masks.",
            value="{nodeCacheFolder}",
        ),
        desc.File(
            name="masks",
            label="Masks",
            description="Generated segmentation masks.",
            semantic="image",
            value=lambda attr: "{nodeCacheFolder}/" + ("<FILESTEM>" if attr.node.keepFilename.value else "<VIEW_ID>") + "." + attr.node.extensionOut.value,
        ),
    ]

    def preprocess(self, node):
        """Collect the image descriptors from the input SfMData and validate inputs.

        Stores the list on ``self.image_paths`` for use by ``processChunk``.

        Raises:
            FileNotFoundError: if the input SfMData references no image files.
            ValueError: if no bounding-box folder is connected.
        """
        input_path = node.input.value
        image_paths = get_image_paths_list(input_path, node.inputx2.value, node.inputx4.value)
        if len(image_paths) == 0:
            raise FileNotFoundError(f'No image files found in {input_path}')
        self.image_paths = image_paths
        if node.bboxesFolder.value == "":
            # Plain string: the previous f-string had no placeholders.
            raise ValueError('No file containing bounding boxes connected')

    def processChunk(self, chunk):
        """Run SAM3 video segmentation for this chunk and write the mask images."""
        # Heavy dependencies are imported lazily so the node can be listed
        # without them being installed.
        from segmentationRDS import image, sam3Utils, bboxUtils
        from sam3.model_builder import build_sam3_video_predictor
        import numpy as np
        import torch
        from pyalicevision import image as avimg
        from PIL import Image
        import OpenImageIO as oiio

        try:
            logger.setLevel(chunk.node.verboseLevel.value.upper())

            if not chunk.node.input:
                logger.warning("Nothing to segment")
                return
            if not chunk.node.output.value:
                return

            logger.info("Chunk range from {} to {}".format(chunk.range.start, chunk.range.last))

            # NOTE(review): the whole image list is used here, not a slice of
            # the chunk range logged above — confirm this is intentional.
            chunk_image_paths = self.image_paths

            # makedirs(exist_ok=True) is race-safe and creates missing parent
            # directories, unlike the check-then-mkdir it replaces.
            os.makedirs(chunk.node.output.value, exist_ok=True)

            gpus_to_use = [torch.cuda.current_device()]
            video_predictor = build_sam3_video_predictor(checkpoint_path=chunk.node.segmentationModelPath.evalValue, gpus_to_use=gpus_to_use)

            metadata_deep_model = {}
            metadata_deep_model["Meshroom:mrSegmentation:DeepModelName"] = "SegmentAnything"
            metadata_deep_model["Meshroom:mrSegmentation:DeepModelVersion"] = "sam3-Video-Crop"

            # bboxes.json decoding.
            # image_paths tuples are:
            # (path, viewId, frameId, width, height, PAR, x2 path, x4 path)
            json_path = os.path.join(chunk.node.bboxesFolder.value, "bboxes.json")
            frame_w = chunk_image_paths[0][3]
            frame_h = chunk_image_paths[0][4]
            par = chunk_image_paths[0][5]
            x2_ok = os.path.exists(chunk.node.inputx2.value)
            x4_ok = os.path.exists(chunk.node.inputx4.value)
            bboxes = bboxUtils.extract_tracking(json_path, frame_w, frame_h, x2_ok, x4_ok, par)

            logger.info(f"bboxes.keys() = {bboxes.keys()}")

            # One full-resolution accumulation buffer per frame, keyed by
            # frame id (image_path[2]); masks of all objects are OR-ed into it.
            full_mask_images = {}
            img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[0][0]), True)
            sourceInfo = {"h_ori": h_ori, "w_ori": w_ori, "PAR": p_a_r, "orientation": orientation}
            for frameId, image_path in enumerate(chunk_image_paths):
                full_mask_images[image_path[2]] = np.zeros_like(img)

            for key, frame_chunks in bboxes.items():

                # Keys are formatted as "<textPrompt>_<objectId>".
                textPrompt = key.split('_')[0]
                obj_id = key.split('_')[1]
                logger.info(f"key = {key} ; text prompt = {textPrompt} ; obj_id = {obj_id}")

                for frame_chunk in frame_chunks:
                    logger.info(frame_chunk)
                    pil_images = []
                    mask_images = []
                    firstFrameId = frame_chunk.start_frame
                    for frame_idx, box in sorted(frame_chunk.boxes.items()):
                        x1, y1, x2, y2 = box
                        box_w = x2 - x1
                        box_h = y2 - y1

                        # Sentinel box sizes select the source resolution:
                        # presumably 252x252 boxes come from a 4x-reduced view, so the
                        # crop is taken from the x4 upscaled image (tuple index 7) with
                        # coordinates scaled by 4; 504x504 from the x2 image (index 6),
                        # coordinates scaled by 2; anything else from the original.
                        # TODO(review): confirm the 252/504 sentinels against bboxUtils.
                        if box_w == 252 and box_h == 252:
                            img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][7]), True)
                            imgBuf = oiio.ImageBuf(img)
                            imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(4*x1, 4*x2, 4*y1, 4*y2))
                        elif box_w == 504 and box_h == 504:
                            img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][6]), True)
                            imgBuf = oiio.ImageBuf(img)
                            imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(2*x1, 2*x2, 2*y1, 2*y2))
                        else:
                            img, h_ori, w_ori, p_a_r, orientation = image.loadImage(str(chunk_image_paths[frame_idx - firstFrameId][0]), True)
                            imgBuf = oiio.ImageBuf(img)
                            imgBuf = oiio.ImageBufAlgo.crop(imgBuf, roi=oiio.ROI(x1, x2, y1, y2))

                        img_crop = imgBuf.get_pixels(format=oiio.FLOAT)
                        # The predictor consumes 8-bit PIL frames.
                        pil_images.append(Image.fromarray((255.0*img_crop).astype("uint8")))
                        mask_images.append(np.zeros_like(img_crop))

                    # Start a SAM3 session over the cropped clip, prompt it with
                    # the text at the first frame, then propagate through the clip.
                    response = video_predictor.handle_request(
                        request=dict(
                            type="start_session",
                            resource_path=pil_images,
                        )
                    )
                    session_id = response["session_id"]

                    video_predictor.handle_request(
                        request=dict(
                            type="add_prompt",
                            session_id=session_id,
                            frame_index=0,
                            text=textPrompt,
                        )
                    )
                    outputs_per_frame = sam3Utils.propagateInVideo(video_predictor, session_id) #, fIdx, max_frame_num_to_track, track_dir)
                    outputs_per_frame_visu = sam3Utils.prepareMasksForVisualization(outputs_per_frame)

                    # Paste each predicted mask back into the full-resolution buffer.
                    for frame_idx, box in sorted(frame_chunk.boxes.items()):
                        x1, y1, x2, y2 = box
                        box_w = x2 - x1
                        box_h = y2 - y1
                        frameId = frame_idx - firstFrameId
                        # Renamed from `key`, which shadowed the outer prompt key.
                        # NOTE(review): buffers were keyed by frame id (image_path[2])
                        # but are indexed here by frame_idx — confirm both use the
                        # same id space.
                        for obj_key, maskBoxProb in outputs_per_frame_visu[frameId].items():
                            mask = maskBoxProb["mask"]
                            buf_in = oiio.ImageBuf(mask.astype('float32'))
                            # Resample the predicted mask back to the original box size.
                            buf_out = oiio.ImageBufAlgo.resample(buf_in, roi=oiio.ROI(0, box_w, 0, box_h))
                            mask = buf_out.get_pixels().reshape(box_h, box_w, 1)
                            tgt = full_mask_images[frame_idx][y1:y2, x1:x2, :]
                            bool_mask = mask.squeeze() > 0
                            tgt[bool_mask] = [255, 255, 255]

                    video_predictor.handle_request(request=dict(type="close_session", session_id=session_id))


            # Write one binarized (optionally inverted) mask per view, carrying
            # the source orientation / PAR metadata.
            for frameId, image_path in enumerate(chunk_image_paths):
                if chunk.node.maskInvert.value:
                    mask = (full_mask_images[image_path[2]][:,:,0:1] == 0).astype('float32')
                else:
                    mask = (full_mask_images[image_path[2]][:,:,0:1] > 0).astype('float32')
                logger.info(f"frameId: {frameId} - {image_path[0]}")

                if chunk.node.keepFilename.value:
                    outputFileMask = os.path.join(chunk.node.output.value, Path(image_path[0]).stem + "." + chunk.node.extensionOut.value)
                else:
                    outputFileMask = os.path.join(chunk.node.output.value, str(image_path[1]) + "." + chunk.node.extensionOut.value)

                optWrite = avimg.ImageWriteOptions()
                optWrite.toColorSpace(avimg.EImageColorSpace_NO_CONVERSION)
                if Path(outputFileMask).suffix.lower() == ".exr":
                    optWrite.exrCompressionMethod(avimg.EImageExrCompression_stringToEnum("DWAA"))
                    optWrite.exrCompressionLevel(300)

                image.writeImage(outputFileMask, mask, sourceInfo["h_ori"], sourceInfo["w_ori"], sourceInfo["orientation"],
                                 sourceInfo["PAR"], metadata_deep_model, optWrite)

        finally:
            # Release GPU memory even if segmentation failed part-way through.
            torch.cuda.empty_cache()
def get_image_paths_list(input_path, path_folder_x2 = "", path_folder_x4 = ""):
    """Build the sorted list of image descriptors from an SfMData file.

    Each entry is a tuple:
    ``(image path, view id (str), frame id, width, height, pixel aspect ratio,
    x2 upscaled path or None, x4 upscaled path or None)``.

    Args:
        input_path: path to an ``.sfm`` or ``.abc`` SfMData file.
        path_folder_x2: folder holding x2 upscaled copies of the source images.
        path_folder_x4: folder holding x4 upscaled copies of the source images.

    Returns:
        The list of tuples sorted by image path; empty if the file does not
        exist or cannot be loaded.

    Raises:
        ValueError: if *input_path* does not have an .sfm/.abc extension.
    """
    from pyalicevision import sfmData, camera
    from pyalicevision import sfmDataIO
    from pathlib import Path

    image_paths = []

    if Path(input_path).suffix.lower() in [".sfm", ".abc"]:
        if Path(input_path).exists():
            dataAV = sfmData.SfMData()
            if sfmDataIO.load(dataAV, input_path, sfmDataIO.ALL):
                views = dataAV.getViews()
                # Renamed loop variable from `id`, which shadowed the builtin.
                for view_id, v in views.items():
                    image_x1_path = Path(v.getImage().getImagePath())
                    image_x1_name = image_x1_path.name
                    # Upscaled variants are matched by filename; None when absent.
                    image_x2_path = None
                    if os.path.isfile(os.path.join(path_folder_x2, image_x1_name)):
                        image_x2_path = os.path.join(path_folder_x2, image_x1_name)
                    image_x4_path = None
                    if os.path.isfile(os.path.join(path_folder_x4, image_x1_name)):
                        image_x4_path = os.path.join(path_folder_x4, image_x1_name)
                    # Pixel aspect ratio comes from the pinhole intrinsic when
                    # available, defaulting to square pixels.
                    intrinsic = dataAV.getIntrinsicSharedPtr(v.getIntrinsicId())
                    pinhole = camera.Pinhole.cast(intrinsic)
                    par = 1.0
                    if pinhole is not None:
                        par = pinhole.getPixelAspectRatio()
                    image_paths.append((image_x1_path, str(view_id), v.getFrameId(), v.getImage().getWidth(),
                                        v.getImage().getHeight(), par, image_x2_path, image_x4_path))

                image_paths.sort(key=lambda x: x[0])
    else:
        raise ValueError(f"Input path '{input_path}' is not a valid path (folder or sfmData file).")
    return image_paths

meshroom/imageSegmentation/VideoSegmentationSam3Text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def processChunk(self, chunk):
354354
boxes[textPrompt]["forward"][firstFrameId + frameId] = {}
355355
for key, maskBoxProb in outputs_per_frame_visu[frameId].items():
356356
mask = maskBoxProb["mask"]
357-
mask_images[frameId][mask] = [255, 255, 255]
357+
mask_images[frameId][mask] = [(int(key) + 1) * 255, 255, 255]
358358
color = colorPalette.at(int(key)) if colorPalette.at(int(key)) is not None else [255, 255, 255]
359359
colorMaskImageFwd[mask] = [x/255.0 for x in color]
360360

@@ -401,7 +401,7 @@ def processChunk(self, chunk):
401401
boxes[textPrompt]["backward"][firstFrameId + frameId] = {}
402402
for key, maskBoxProb in outputs_per_frame_visu[frameId].items():
403403
mask = maskBoxProb["mask"]
404-
mask_images[frameId][mask] = [255, 255, 255]
404+
mask_images[frameId][mask] = [(int(key) + 1) * 255, 255, 255]
405405
color = colorPalette.at(int(key)) if colorPalette.at(int(key)) is not None else [255, 255, 255]
406406
colorMaskImageBwd[mask] = [x/255.0 for x in color]
407407
if chunk.node.outputCryptomatte.value:

0 commit comments

Comments
 (0)