Skip to content

Commit a2e14e3

Browse files
committed
Add color palette management and cryptomatte generation
1 parent 5a2f9db commit a2e14e3

2 files changed

Lines changed: 140 additions & 44 deletions

File tree

meshroom/imageSegmentation/VideoSegmentationSam3.py

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
import os
44
from pathlib import Path
5+
import struct
56

67
from meshroom.core import desc
78
from meshroom.core.utils import VERBOSE_LEVEL
89

10+
import logging
11+
logger = logging.getLogger("VideoSegmentationSam3")
12+
913
class Sam3VideoNodeSize(desc.MultiDynamicNodeSize):
1014
def computeSize(self, node):
1115
size = 1
@@ -17,31 +21,10 @@ class VideoSegmentationSam3(desc.Node):
1721

1822
category = "Utils"
1923
documentation = """
20-
Based on the Segment Anything video predictor model 3, the node generates a binary mask from a text prompt,
21-
a single bounding box or a set of positive and negative clicks (Clicks In/Out).
22-
Two masks are generated, a binary one and a colored one that the indexes of every submasks.
23-
Object Ids are color encoded as follow:
24-
0:[1,0,0] = 0xff0000
25-
1:[0,1,0] = 0x00ff00
26-
2:[0,0,1] = 0x0000ff
27-
3:[1,1,0] = 0xffff00
28-
4:[1,0,1] = 0xff00ff
29-
5:[0,1,1] = 0x00ffff
30-
6:[1,0,0.5] = 0xff0080
31-
7:[0,1,0.5] = 0x00ff80
32-
8:[0,0.5,1] = 0x0080ff
33-
9:[1,1,0.5] = 0xffff80
34-
10:[1,0.5,1] = 0xff80ff
35-
11:[0.5,1,1] = 0x80ffff
36-
12:[1,0.5,0] = 0xff8000
37-
13:[0.5,1,0] = 0x80ff00
38-
14:[0.5,0,1] = 0x8000ff
39-
15:[1,0.5,0.5] = 0xff8080
40-
16:[0.5,1,0.5] = 0x80ff80
41-
17:[0.5,0.5,1] = 0x8080ff
42-
18:[1,1,1] = 0xffffff
43-
After that, refinement is possible through in/out points for every segmented objects.
44-
In order to associate a point to a given submask, it must be colored with the corresponding color.
24+
Based on the Segment Anything video predictor model 3, the node generates a binary mask, a colored mask and an exr cryptomatte
25+
from a text prompt, a single bounding box or a set of positive and negative clicks (Clicks In/Out).
26+
Text prompt and Clicks can be combined to refine results. For refinement, points must be associated to an existing submask.
27+
In order to associate a point to a given submask, it must be colored with the submask's color.
4528
"""
4629

4730
inputs = [
@@ -87,6 +70,12 @@ class VideoSegmentationSam3(desc.Node):
8770
description="Invert mask values. If selected, the pixels corresponding to the mask will be set to 0 instead of 255.",
8871
value=False,
8972
),
73+
desc.BoolParam(
74+
name="outputCryptomatte",
75+
label="Output Cryptomatte",
76+
description="Generate exr images containing cryptomatte to encode the segmentation results.",
77+
value=False,
78+
),
9079
desc.BoolParam(
9180
name="useGpu",
9281
label="Use GPU",
@@ -175,6 +164,15 @@ class VideoSegmentationSam3(desc.Node):
175164
value=lambda attr: "{nodeCacheFolder}/colorMask_" + ("<FILESTEM>" if attr.node.keepFilename.value else "<VIEW_ID>") + ".png",
176165
group="",
177166
),
167+
desc.File(
168+
name="cryptomatte",
169+
label="Cryptomatte",
170+
description="Cryptomatte embeded in exr images.",
171+
semantic="image",
172+
value=lambda attr: "{nodeCacheFolder}/cryptomatte_" + ("<FILESTEM>" if attr.node.keepFilename.value else "<VIEW_ID>") + ".exr",
173+
enabled=lambda node: node.outputCryptomatte.value,
174+
group="",
175+
),
178176
]
179177

180178
def prepare_masks_for_visualization(self, frame_to_output):
@@ -270,6 +268,14 @@ def normalize_bbox(self, bbox_xywh, img_w, img_h, PAR, orientation):
270268
normalized_bbox[..., 3] /= img_h
271269
return normalized_bbox
272270

271+
def hash_name(self, name):
272+
import mmh3
273+
import numpy as np
274+
hash_32 = mmh3.hash(name, seed=0) & 0xFFFFFFFF
275+
f32_val = np.frombuffer(struct.pack('<I',hash_32), dtype=np.float32)[0]
276+
f32_hex = hex(struct.unpack('<I', struct.pack('<f', f32_val))[0])[2:]
277+
return f32_val, f32_hex, hash_32
278+
273279

274280
def preprocess(self, node):
275281
extension = node.extensionIn.value
@@ -288,17 +294,18 @@ def processChunk(self, chunk):
288294
import torch
289295
from pyalicevision import image as avimg
290296
from PIL import Image
297+
import OpenImageIO as oiio
291298

292299
try:
293-
chunk.logManager.start(chunk.node.verboseLevel.value)
300+
logger.setLevel(chunk.node.verboseLevel.value.upper())
294301

295302
if not chunk.node.input:
296-
chunk.logger.warning("Nothing to segment")
303+
logger.warning("Nothing to segment")
297304
return
298305
if not chunk.node.output.value:
299306
return
300307

301-
chunk.logger.info("Chunk range from {} to {}".format(chunk.range.start, chunk.range.last))
308+
logger.info("Chunk range from {} to {}".format(chunk.range.start, chunk.range.last))
302309

303310
chunk_image_paths = self.image_paths
304311

@@ -320,9 +327,7 @@ def processChunk(self, chunk):
320327
clicks = {}
321328
bboxes = {}
322329

323-
colors=[[255,0,0],[0,255,0],[0,0,255],[255,255,0],[255,0,255],[0,255,255],
324-
[255,0,128],[0,255,128],[0,128,255],[255,255,128],[255,128,255],[128,255,255],
325-
[255,128,0],[128,255,0],[128,0,255],[255,128,128],[128,255,128],[128,128,255],[255,255,255]]
330+
colorPalette = image.paletteGenerator()
326331

327332
for idx, path in enumerate(chunk_image_paths):
328333
img, h_ori, w_ori, PAR, orientation = image.loadImage(str(chunk_image_paths[idx][0]), True)
@@ -335,10 +340,10 @@ def processChunk(self, chunk):
335340
objects = {}
336341
if viewId is not None and viewId in posClickDictFromShape:
337342
for pt in posClickDictFromShape[viewId]:
338-
color = [int(pt[1][1:3], 16), int(pt[1][3:5], 16), int(pt[1][5:], 16)]
339-
if color not in colors:
340-
colors.append(color)
341-
objId = colors.index(color)
343+
color = (int(pt[1][1:3], 16), int(pt[1][3:5], 16), int(pt[1][5:], 16))
344+
if colorPalette.index(color) is None:
345+
colorPalette.add_color(color)
346+
objId = colorPalette.index(color)
342347

343348
if objId not in objects:
344349
objects[objId] = [[], []]
@@ -349,10 +354,10 @@ def processChunk(self, chunk):
349354

350355
if viewId is not None and viewId in negClickDictFromShape:
351356
for pt in negClickDictFromShape[viewId]:
352-
color = [int(pt[1][1:3], 16), int(pt[1][3:5], 16), int(pt[1][5:], 16)]
353-
if color not in colors:
354-
colors.append(color)
355-
objId = colors.index(color)
357+
color = (int(pt[1][1:3], 16), int(pt[1][3:5], 16), int(pt[1][5:], 16))
358+
if colorPalette.index(color) is None:
359+
colorPalette.add_color(color)
360+
objId = colorPalette.index(color)
356361

357362
if objId not in objects:
358363
objects[objId] = [[], []]
@@ -372,8 +377,8 @@ def processChunk(self, chunk):
372377
bboxes[frameId][0].append(bbox)
373378
bboxes[frameId][1].append(1)
374379

375-
chunk.logger.debug(f"clicks = {clicks}")
376-
chunk.logger.debug(f"bboxes = {bboxes}")
380+
logger.debug(f"clicks = {clicks}")
381+
logger.debug(f"bboxes = {bboxes}")
377382

378383
response = video_predictor.handle_request(
379384
request=dict(
@@ -427,15 +432,51 @@ def processChunk(self, chunk):
427432
for frameId, masks in outputs_per_frame.items():
428433
maskImage = np.zeros_like(img)
429434
colorMaskImage = np.zeros_like(img)
435+
if chunk.node.outputCryptomatte.value:
436+
crypto_id = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32)
437+
crypto_cov = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32)
438+
crypto_zeros = np.zeros((img.shape[0], img.shape[1]), dtype=np.float32)
439+
manifest = {}
440+
441+
colorPalette.generate_palette(max(masks.keys()) + 1)
442+
cryptoName = "object" if chunk.node.prompt.value=="" else chunk.node.prompt.value
430443
for key, mask in masks.items():
431444
maskImage[mask] = [255, 255, 255]
432-
colorMaskImage[mask] = [x/255.0 for x in colors[int(key) % len(colors)]]
445+
colorMaskImage[mask] = [x/255.0 for x in colorPalette.at(int(key))]
446+
if chunk.node.outputCryptomatte.value:
447+
obj_name = f"{cryptoName}_{int(key)}"
448+
f32_hash, hex_val, _ = self.hash_name(obj_name)
449+
manifest[obj_name] = hex_val
450+
crypto_id[mask] = f32_hash
451+
crypto_cov[mask] = 1.0
452+
453+
if chunk.node.outputCryptomatte.value:
454+
spec = oiio.ImageSpec(img.shape[1], img.shape[0], 7, oiio.FLOAT)
455+
spec.channelnames = (cryptoName+".red", cryptoName+".green", cryptoName+".blue",
456+
cryptoName+"00.red", cryptoName+"00.green", cryptoName+"00.blue", cryptoName+"00.alpha")
457+
_, _, h32 = self.hash_name(cryptoName)
458+
key = f"{h32 & 0xFFFFFFFF:08x}"[:7]
459+
spec.attribute(f"cryptomatte/{key}/name", cryptoName)
460+
spec.attribute(f"cryptomatte/{key}/manifest", json.dumps(manifest))
461+
spec.attribute(f"cryptomatte/{key}/hash", "MurmurHash3_32")
462+
spec.attribute(f"cryptomatte/{key}/conversion", "uint32_to_float32")
463+
464+
if chunk.node.keepFilename.value:
465+
cryptomattePath = os.path.join(chunk.node.output.value, "cryptomatte_" + str(Path(chunk_image_paths[frameId][0]).stem) + ".exr")
466+
else:
467+
cryptomattePath = os.path.join(chunk.node.output.value, "cryptomatte_" + str(chunk_image_paths[frameId][1]) + ".exr")
468+
469+
cryptomatteImg = oiio.ImageOutput.create(str(cryptomattePath))
470+
cryptomatteImg.open(cryptomattePath, spec)
471+
cryptomatte_data = np.dstack((crypto_zeros, crypto_zeros, crypto_zeros, crypto_id, crypto_cov, crypto_zeros, crypto_zeros))
472+
cryptomatteImg.write_image(cryptomatte_data)
473+
cryptomatteImg.close()
433474

434475
if chunk.node.maskInvert.value:
435476
mask = (maskImage[:,:,0:1] == 0).astype('float32')
436477
else:
437478
mask = (maskImage[:,:,0:1] > 0).astype('float32')
438-
chunk.logger.info("frameId: {} - {}".format(frameId, chunk_image_paths[frameId][0]))
479+
logger.info("frameId: {} - {}".format(frameId, chunk_image_paths[frameId][0]))
439480

440481
if chunk.node.keepFilename.value:
441482
outputFileMask = os.path.join(chunk.node.output.value, Path(chunk_image_paths[frameId][0]).stem + "." + chunk.node.extensionOut.value)
@@ -455,7 +496,6 @@ def processChunk(self, chunk):
455496

456497
finally:
457498
torch.cuda.empty_cache()
458-
chunk.logManager.end()
459499

460500

461501
def get_image_paths_list(input_path, extension):

segmentationRDS/image.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
import random
13
import os
24
import numpy as np
35
import OpenImageIO as oiio
@@ -229,3 +231,57 @@ def addText(image: np.ndarray, text, x, y, size, color = (255, 0, 0)) -> np.ndar
229231
oiio.ImageBufAlgo.render_text(buf, int(x), int(y), text, int(size), "", color)
230232
return buf.get_pixels(format='uint8')
231233

234+
class paletteGenerator:
235+
def __init__(self, seed=42):
236+
self.seed = seed
237+
self.colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
238+
239+
def _dist(self, c1, c2):
240+
return math.sqrt(sum((a - b) ** 2 for a, b in zip(c1, c2)))
241+
242+
def _is_grey(self, color, threshold=50):
243+
return (max(color) - min(color)) < threshold
244+
245+
def add_color(self, color=None):
246+
if color:
247+
self.colors.append(color)
248+
return color
249+
250+
random.seed = self.seed + len(self.colors)
251+
best_color = None
252+
max_dist_min = -1
253+
254+
for _ in range (2000):
255+
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
256+
if self._is_grey(color):
257+
continue
258+
dist_min = min(self._dist(color, c) for c in self.colors)
259+
if dist_min > max_dist_min:
260+
max_dist_min = dist_min
261+
best_color = color
262+
263+
self.colors.append(best_color)
264+
return best_color
265+
266+
def generate_palette(self, n):
267+
268+
if n < 0:
269+
return None
270+
if n <= len(self.colors):
271+
return self.colors[0:n]
272+
273+
missing = n - len(self.colors)
274+
for _ in range(missing):
275+
self.add_color()
276+
return self.colors
277+
278+
def index(self, color):
279+
if color in self.colors:
280+
return self.colors.index(color)
281+
return None
282+
283+
def at(self, idx):
284+
if idx >=0 and idx < len(self.colors):
285+
c = [x for x in self.colors[idx]]
286+
return c
287+
return None

0 commit comments

Comments
 (0)