forked from facebookresearch/sam3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwebcam_prompt_highlighter.py
More file actions
81 lines (68 loc) · 2.89 KB
/
webcam_prompt_highlighter.py
File metadata and controls
81 lines (68 loc) · 2.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import base64
import io
import os
from flask import Flask, request, Response, send_file
from PIL import Image, ImageDraw
import torch
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
app = Flask(__name__)

# Page served by the "/" route. It is resolved relative to this source file,
# so the lookup does not depend on the process's working directory.
HTML_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "webcam_prompt_highlighter.html"
)

# Build the SAM3 model and its processor once, at import time, so every
# request handled by this process reuses the same instances.
print("Loading SAM3 model...")
model = build_sam3_image_model()
processor = Sam3Processor(model)
print("SAM3 model loaded.")
@app.get("/")
def index():
    """Serve the single-page webcam UI stored at ``HTML_PATH``.

    The file is read on every request and returned as ``text/html``.
    """
    with open(HTML_PATH, encoding="utf-8") as page:
        markup = page.read()
    return Response(markup, mimetype="text/html")
def _to_numpy(value):
    """Return *value* as a NumPy array if it is a torch tensor, else unchanged.

    The tensor is detached, cast to float32 and copied to the CPU first,
    because ``Tensor.numpy()`` rejects tensors that require grad, live on an
    accelerator, or use a dtype NumPy lacks (such as bfloat16). Boolean masks
    become 0.0/1.0, which the overlay below treats exactly like False/True.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().float().cpu().numpy()
    return value


def _decode_image(data_url):
    """Decode a base64 payload, optionally given as a data URL, to an RGB image.

    Everything up to and including the first comma (e.g. a
    ``data:image/jpeg;base64,`` prefix) is discarded before decoding.
    """
    if "," in data_url:
        _, data_url = data_url.split(",", 1)
    image_bytes = base64.b64decode(data_url)
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _overlay_masks(image, masks, color=(13, 219, 138, 120)):
    """Return *image* with every region in *masks* tinted by RGBA *color*.

    The default color is a semi-transparent green. Each mask must be 2-D
    (H, W) or 3-D with a squeezable singleton axis such as (1, H, W); any
    other shape is skipped. Mask values are clamped to [0, 1], resized to
    the image size, and used as the per-pixel blend weight of the tint.
    """
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    # The tint is the same for every mask, so build it once outside the loop.
    color_layer = Image.new("RGBA", image.size, color)
    for mask in masks:
        mask = _to_numpy(mask)  # no-op for arrays; handles lists of tensors
        if mask.ndim == 3:
            mask = mask.squeeze()
        if mask.ndim != 2:
            continue  # skip masks whose shape we cannot interpret
        # Clamp before scaling: values outside [0, 1] (0/255 masks, raw
        # scores) would otherwise wrap around when cast to uint8.
        alpha = (mask.astype("float32").clip(0.0, 1.0) * 255).astype("uint8")
        # A 2-D uint8 array is inferred as mode "L"; the explicit mode=
        # argument is deprecated in recent Pillow releases.
        mask_pil = Image.fromarray(alpha).resize(image.size)
        overlay = Image.composite(color_layer, overlay, mask_pil)
    return Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")


@app.post("/segment")
def segment():
    """Highlight the regions SAM3 finds for a text prompt in one image.

    The request body is parsed as JSON (regardless of Content-Type) and must
    be an object with:
      * ``prompt`` -- non-empty text prompt passed to SAM3.
      * ``image``  -- base64-encoded image, optionally as a data URL.

    Returns:
        The image as a PNG, with every returned mask tinted semi-transparent
        green. If no masks are returned, the decoded image is sent back
        unchanged.

    Errors are returned as ``{"error": ...}`` JSON:
      * 400 for a bad request body.
      * 500 when the model or overlay step raises.
    """
    payload = request.get_json(force=True) or {}
    if not isinstance(payload, dict):
        return {"error": "request body must be a JSON object"}, 400

    prompt = payload.get("prompt") or ""
    if not isinstance(prompt, str):
        return {"error": "prompt must be a string"}, 400
    prompt = prompt.strip()
    if not prompt:
        return {"error": "prompt is required"}, 400

    data_url = payload.get("image") or ""
    if not isinstance(data_url, str):
        return {"error": "invalid image: expected a base64 string"}, 400
    try:
        image = _decode_image(data_url)
    except Exception as exc:  # bad base64, unreadable image data, ...
        return {"error": f"invalid image: {exc}"}, 400

    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        # NOTE(review): presumably (N, H, W) or (N, 1, H, W), as a tensor or
        # array -- confirm against Sam3Processor.
        masks = _to_numpy(output["masks"])
        if masks is not None and len(masks) > 0:
            image = _overlay_masks(image, masks)
    except Exception as exc:
        # Keep the full traceback in the server log; the client gets a summary.
        app.logger.exception("SAM3 segmentation failed")
        return {"error": f"model error: {exc}"}, 500

    buf = io.BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
if __name__ == "__main__":
    # Bind to all interfaces (0.0.0.0) with Flask's debug mode disabled.
    # NOTE(review): this exposes an unauthenticated endpoint to the network.
    app.run(debug=False, port=8000, host="0.0.0.0")