1+ __version__ = "1.0"
2+
3+ from functools import total_ordering
4+ import os
5+ from pathlib import Path
6+
7+ from meshroom .core import desc
8+ from meshroom .core .utils import VERBOSE_LEVEL
9+ from pyalicevision import parallelization as avpar
10+
11+ import logging
12+ logger = logging .getLogger ("VideoSegmentationSam3Boxes" )
13+
14+ class VideoSegmentationSam3Boxes (desc .Node ):
15+ size = avpar .DynamicViewsSize ("input" )
16+ gpu = desc .Level .EXTREME
17+
18+ category = "Segmentation"
19+ documentation = """
20+ Based on the Segment Anything video predictor model 3, the node generates binary masks from a set of
21+ bounding boxes contained in a json file.
22+ """
23+
24+ inputs = [
25+ desc .File (
26+ name = "input" ,
27+ label = "Input" ,
28+ description = "SfMData file." ,
29+ value = "" ,
30+ ),
31+ desc .File (
32+ name = "inputx2" ,
33+ label = "Inputx2" ,
34+ description = "Folder containing source images upscale by 2." ,
35+ value = "" ,
36+ ),
37+ desc .File (
38+ name = "inputx4" ,
39+ label = "Inputx4" ,
40+ description = "Folder containing source images upscale by 4." ,
41+ value = "" ,
42+ ),
43+ desc .File (
44+ name = "masksFolder" ,
45+ label = "Masks Folder" ,
46+ description = "Folder containing the masks computed at original resolution." ,
47+ value = "" ,
48+ ),
49+ desc .File (
50+ name = "bboxesFolder" ,
51+ label = "Bounding Boxes Folder" ,
52+ description = "Folder containing the bboxes.json file associated to the sfmData used as input." ,
53+ value = "" ,
54+ ),
55+ desc .File (
56+ name = "segmentationModelPath" ,
57+ label = "Segmentation Model" ,
58+ description = "Weights file for the segmentation model." ,
59+ value = "${RDS_SAM3_MODEL_PATH}" ,
60+ ),
61+ desc .BoolParam (
62+ name = "maskInvert" ,
63+ label = "Invert Masks" ,
64+ description = "Invert mask values. If selected, the pixels corresponding to the mask will be set to 0 instead of 255." ,
65+ value = False ,
66+ ),
67+ desc .BoolParam (
68+ name = "useGpu" ,
69+ label = "Use GPU" ,
70+ description = "Use GPU for computation if available." ,
71+ value = True ,
72+ invalidate = False ,
73+ ),
74+ desc .BoolParam (
75+ name = "keepFilename" ,
76+ label = "Keep Filename" ,
77+ description = "Keep the filename of the inputs for the outputs." ,
78+ value = True ,
79+ ),
80+ desc .ChoiceParam (
81+ name = "extensionOut" ,
82+ label = "Output File Extension" ,
83+ description = "Output image file extension.\n "
84+ "If unset, the output file extension will match the input's if possible." ,
85+ value = "exr" ,
86+ values = ["exr" , "png" , "jpg" ],
87+ exclusive = True ,
88+ ),
89+ desc .ChoiceParam (
90+ name = "verboseLevel" ,
91+ label = "Verbose Level" ,
92+ description = "Verbosity level (fatal, error, warning, info, debug)." ,
93+ value = "info" ,
94+ values = VERBOSE_LEVEL ,
95+ exclusive = True ,
96+ ),
97+ ]
98+
99+ outputs = [
100+ desc .File (
101+ name = "output" ,
102+ label = "Masks Folder" ,
103+ description = "Output path for the masks." ,
104+ value = "{nodeCacheFolder}" ,
105+ ),
106+ desc .File (
107+ name = "masks" ,
108+ label = "Masks" ,
109+ description = "Generated segmentation masks." ,
110+ semantic = "image" ,
111+ value = lambda attr : "{nodeCacheFolder}/" + ("<FILESTEM>" if attr .node .keepFilename .value else "<VIEW_ID>" ) + "." + attr .node .extensionOut .value ,
112+ ),
113+ ]
114+
115+ def preprocess (self , node ):
116+ import re
117+ input_path = node .input .value
118+ image_paths = get_image_paths_list (input_path , node .inputx2 .value , node .inputx4 .value )
119+ if len (image_paths ) == 0 :
120+ raise FileNotFoundError (f'No image files found in { input_path } ' )
121+ self .image_paths = image_paths
122+ if node .bboxesFolder .value == "" :
123+ raise ValueError (f'No file containing bounding boxes connected' )
124+
125+ def processChunk (self , chunk ):
126+ from segmentationRDS import image , sam3Utils , bboxUtils
127+ from sam3 .model_builder import build_sam3_video_predictor
128+ import numpy as np
129+ import torch
130+ from pyalicevision import image as avimg
131+ from PIL import Image
132+ import OpenImageIO as oiio
133+
134+ try :
135+ logger .setLevel (chunk .node .verboseLevel .value .upper ())
136+
137+ if not chunk .node .input :
138+ logger .warning ("Nothing to segment" )
139+ return
140+ if not chunk .node .output .value :
141+ return
142+
143+ logger .info ("Chunk range from {} to {}" .format (chunk .range .start , chunk .range .last ))
144+
145+ chunk_image_paths = self .image_paths
146+
147+ if not os .path .exists (chunk .node .output .value ):
148+ os .mkdir (chunk .node .output .value )
149+
150+ gpus_to_use = [torch .cuda .current_device ()]
151+ video_predictor = build_sam3_video_predictor (checkpoint_path = chunk .node .segmentationModelPath .evalValue , gpus_to_use = gpus_to_use )
152+
153+ metadata_deep_model = {}
154+ metadata_deep_model ["Meshroom:mrSegmentation:DeepModelName" ] = "SegmentAnything"
155+ metadata_deep_model ["Meshroom:mrSegmentation:DeepModelVersion" ] = "sam3-Video-Crop"
156+
157+ # bboxes.json decoding
158+ json_path = os .path .join (chunk .node .bboxesFolder .value , "bboxes.json" )
159+ frame_w = chunk_image_paths [0 ][3 ]
160+ frame_h = chunk_image_paths [0 ][4 ]
161+ par = chunk_image_paths [0 ][5 ]
162+ x2_ok = os .path .exists (chunk .node .inputx2 .value )
163+ x4_ok = os .path .exists (chunk .node .inputx4 .value )
164+ bboxes = bboxUtils .extract_tracking (json_path , frame_w , frame_h , x2_ok , x4_ok , par )
165+
166+ logger .info (f"bboxes.keys() = { bboxes .keys ()} " )
167+
168+ full_mask_images = {}
169+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [0 ][0 ]), True )
170+ sourceInfo = {"h_ori" : h_ori , "w_ori" : w_ori , "PAR" : p_a_r , "orientation" : orientation }
171+ for frameId , image_path in enumerate (chunk_image_paths ):
172+ full_mask_images [image_path [2 ]] = np .zeros_like (img )
173+
174+ for key , frame_chunks in bboxes .items ():
175+
176+ textPrompt = key .split ('_' )[0 ]
177+ obj_id = key .split ('_' )[1 ]
178+ logger .info (f"key = { key } ; text prompt = { textPrompt } ; obj_id = { obj_id } " )
179+
180+ for frame_chunk in frame_chunks :
181+ logger .info (frame_chunk )
182+ pil_images = []
183+ mask_images = []
184+ firstFrameId = frame_chunk .start_frame
185+ for frame_idx , box in sorted (frame_chunk .boxes .items ()):
186+ x1 , y1 , x2 , y2 = box
187+ box_w = x2 - x1
188+ box_h = y2 - y1
189+
190+ if box_w == 252 and box_h == 252 :
191+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][7 ]), True )
192+ imgBuf = oiio .ImageBuf (img )
193+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (4 * x1 , 4 * x2 , 4 * y1 , 4 * y2 ))
194+ elif box_w == 504 and box_h == 504 :
195+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][6 ]), True )
196+ imgBuf = oiio .ImageBuf (img )
197+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (2 * x1 , 2 * x2 , 2 * y1 , 2 * y2 ))
198+ else :
199+ img , h_ori , w_ori , p_a_r , orientation = image .loadImage (str (chunk_image_paths [frame_idx - firstFrameId ][0 ]), True )
200+ imgBuf = oiio .ImageBuf (img )
201+ imgBuf = oiio .ImageBufAlgo .crop (imgBuf , roi = oiio .ROI (x1 , x2 , y1 , y2 ))
202+
203+ img_crop = imgBuf .get_pixels (format = oiio .FLOAT )
204+ pil_images .append (Image .fromarray ((255.0 * img_crop ).astype ("uint8" )))
205+ mask_images .append (np .zeros_like (img_crop ))
206+
207+ response = video_predictor .handle_request (
208+ request = dict (
209+ type = "start_session" ,
210+ resource_path = pil_images ,
211+ )
212+ )
213+ session_id = response ["session_id" ]
214+
215+ video_predictor .handle_request (
216+ request = dict (
217+ type = "add_prompt" ,
218+ session_id = session_id ,
219+ frame_index = 0 ,
220+ text = textPrompt ,
221+ )
222+ )
223+ outputs_per_frame = sam3Utils .propagateInVideo (video_predictor , session_id ) #, fIdx, max_frame_num_to_track, track_dir)
224+ outputs_per_frame_visu = sam3Utils .prepareMasksForVisualization (outputs_per_frame )
225+
226+ for frame_idx , box in sorted (frame_chunk .boxes .items ()):
227+ x1 , y1 , x2 , y2 = box
228+ box_w = x2 - x1
229+ box_h = y2 - y1
230+ frameId = frame_idx - firstFrameId
231+ for key , maskBoxProb in outputs_per_frame_visu [frameId ].items ():
232+ mask = maskBoxProb ["mask" ]
233+ buf_in = oiio .ImageBuf (mask .astype ('float32' ))
234+ buf_out = oiio .ImageBufAlgo .resample (buf_in , roi = oiio .ROI (0 , box_w , 0 , box_h ))
235+ mask = buf_out .get_pixels ().reshape (box_h , box_w , 1 )
236+ tgt = full_mask_images [frame_idx ][y1 :y2 ,x1 :x2 , :]
237+ bool_mask = mask .squeeze () > 0
238+ tgt [bool_mask ] = [255 , 255 , 255 ]
239+
240+ video_predictor .handle_request (request = dict (type = "close_session" , session_id = session_id ))
241+
242+
243+ for frameId , image_path in enumerate (chunk_image_paths ):
244+ if chunk .node .maskInvert .value :
245+ mask = (full_mask_images [image_path [2 ]][:,:,0 :1 ] == 0 ).astype ('float32' )
246+ else :
247+ mask = (full_mask_images [image_path [2 ]][:,:,0 :1 ] > 0 ).astype ('float32' )
248+ logger .info (f"frameId: { frameId } - { image_path [0 ]} " )
249+
250+ if chunk .node .keepFilename .value :
251+ outputFileMask = os .path .join (chunk .node .output .value , Path (image_path [0 ]).stem + "." + chunk .node .extensionOut .value )
252+ else :
253+ outputFileMask = os .path .join (chunk .node .output .value , str (image_path [1 ]) + "." + chunk .node .extensionOut .value )
254+
255+ optWrite = avimg .ImageWriteOptions ()
256+ optWrite .toColorSpace (avimg .EImageColorSpace_NO_CONVERSION )
257+ if Path (outputFileMask ).suffix .lower () == ".exr" :
258+ optWrite .exrCompressionMethod (avimg .EImageExrCompression_stringToEnum ("DWAA" ))
259+ optWrite .exrCompressionLevel (300 )
260+
261+ image .writeImage (outputFileMask , mask , sourceInfo ["h_ori" ], sourceInfo ["w_ori" ], sourceInfo ["orientation" ],
262+ sourceInfo ["PAR" ], metadata_deep_model , optWrite )
263+
264+ finally :
265+ torch .cuda .empty_cache ()
266+
267+
268+ def get_image_paths_list (input_path , path_folder_x2 = "" , path_folder_x4 = "" ):
269+ from pyalicevision import sfmData , camera
270+ from pyalicevision import sfmDataIO
271+ from pathlib import Path
272+
273+ image_paths = []
274+
275+ if Path (input_path ).suffix .lower () in [".sfm" , ".abc" ]:
276+ if Path (input_path ).exists ():
277+ dataAV = sfmData .SfMData ()
278+ if sfmDataIO .load (dataAV , input_path , sfmDataIO .ALL ):
279+ views = dataAV .getViews ()
280+ for id , v in views .items ():
281+ image_x1_path = Path (v .getImage ().getImagePath ())
282+ image_x1_name = image_x1_path .name
283+ image_x2_path = None
284+ if os .path .isfile (os .path .join (path_folder_x2 , image_x1_name )):
285+ image_x2_path = os .path .join (path_folder_x2 , image_x1_name )
286+ image_x4_path = None
287+ if os .path .isfile (os .path .join (path_folder_x4 , image_x1_name )):
288+ image_x4_path = os .path .join (path_folder_x4 , image_x1_name )
289+ intrinsic = dataAV .getIntrinsicSharedPtr (v .getIntrinsicId ())
290+ pinhole = camera .Pinhole .cast (intrinsic )
291+ par = 1.0
292+ if pinhole is not None :
293+ par = pinhole .getPixelAspectRatio ()
294+ image_paths .append ((image_x1_path , str (id ), v .getFrameId (), v .getImage ().getWidth (),
295+ v .getImage ().getHeight (), par , image_x2_path , image_x4_path ))
296+
297+ image_paths .sort (key = lambda x : x [0 ])
298+ else :
299+ raise ValueError (f"Input path '{ input_path } ' is not a valid path (folder or sfmData file)." )
300+ return image_paths
0 commit comments