-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathpredictor.py
More file actions
490 lines (406 loc) · 20.4 KB
/
predictor.py
File metadata and controls
490 lines (406 loc) · 20.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
# Copyright (C) 2025-2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""PyTorch backend implementation for SAM predictor."""
from logging import getLogger
from pathlib import Path
from typing import ClassVar
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from segment_anything_hq import sam_model_registry
from segment_anything_hq.modeling.prompt_encoder import PositionEmbeddingRandom as _PositionEmbeddingRandom
from segment_anything_hq.modeling.prompt_encoder import PromptEncoder as _PromptEncoder
from segment_anything_hq.predictor import SamPredictor as _SamPredictor
from torch import nn
from instantlearn.data import ResizeLongestSide
from instantlearn.data.utils import read_image
from instantlearn.utils.constants import DATA_PATH, MODEL_MAP, SAMModelName
from instantlearn.utils.optimization import optimize_model
from instantlearn.utils.utils import download_file, precision_to_torch_dtype
logger = getLogger("Geti Instant Learn")
def load_sam_model(
sam: SAMModelName,
device: str = "cuda",
precision: str = "bf16",
compile_models: bool = False,
model_path: Path | None = None,
target_length: int = 1024,
) -> "SAMPredictor":
"""Load and return a SAM predictor with specified backend.
This function provides a unified interface for loading SAM models with
different backends (PyTorch, OpenVINO). The backend parameter determines
which implementation to use.
Args:
sam: The SAM model architecture to load (e.g., SAM_HQ_TINY, SAM2_BASE)
device: Device to run inference on:
- PyTorch backend: "cuda", "cpu"
- OpenVINO backend: "CPU", "GPU", "AUTO"
precision: Model precision for PyTorch backend ("bf16", "fp32", "fp16").
Ignored for OpenVINO backend (precision is baked into IR).
compile_models: Whether to compile model (PyTorch only).
Ignored for OpenVINO backend.
model_path: Optional path to model weights:
- PyTorch: Path to .pth checkpoint (auto-downloads if None)
- OpenVINO: Path to .xml IR file (required)
target_length: Target length for the longest side of the image during transformation. Defaults to 1024.
Returns:
A SAM predictor instance (PyTorchSAMPredictor or OpenVINOSAMPredictor).
Raises:
ValueError: If the model type or backend is invalid.
Examples:
>>> # PyTorch backend with auto-download
>>> predictor = load_sam_model(
... SAMModelName.SAM_HQ_TINY,
... device="cuda",
... )
"""
if sam not in MODEL_MAP:
msg = f"Invalid model type: {sam}"
raise ValueError(msg)
predictor = SAMPredictor(
sam_model_name=sam,
device=device,
model_path=model_path,
target_length=target_length,
)
# Apply PyTorch-specific optimizations
predictor._predictor = optimize_model(
model=predictor._predictor,
device=device,
precision=precision_to_torch_dtype(precision),
compile_models=compile_models,
)
return predictor
def check_model_weights(model_id: str) -> None:
"""Check if model weights exist locally, download if necessary.
Args:
model_id: The model ID (e.g., "sam-hq-tiny", "sam2-base").
Raises:
ValueError: If the model is not found in registry or has no weights URL.
"""
model = get_model(model_id)
if model is None:
valid = [m.id for m in get_models_by_type(ModelType.SEGMENTER)]
msg = f"Model '{model_id}' not found. Valid segmenters: {valid}"
raise ValueError(msg)
if model.weights_url is None:
msg = f"Model '{model_id}' has no weights_url configured."
raise ValueError(msg)
local_filename = get_local_filename(model_id)
target_path = DATA_PATH.joinpath(local_filename)
if not target_path.exists():
msg = f"Model weights for {model_id} not found at {target_path}, downloading..."
logger.info(msg)
download_file(model.weights_url, target_path, model.sha_sum)
class PositionEmbeddingRandom(_PositionEmbeddingRandom):
"""Dtype-aware positional encoding using random spatial frequencies.
This is a drop-in replacement for segment_anything_hq's PositionEmbeddingRandom
that preserves the model's dtype (e.g., bfloat16) instead of forcing float32.
The original implementation hardcodes `coords.to(torch.float)` which causes
dtype mismatch when the model runs in bfloat16 or float16 precision.
See Also:
segment_anything_hq.modeling.prompt_encoder.PositionEmbeddingRandom
"""
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points normalized to [0,1], preserving dtype."""
# Convert coords to match the gaussian matrix dtype
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
coords = 2 * coords - 1
coords @= self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward_with_coords(
self,
coords_input: torch.Tensor,
image_size: tuple[int, int],
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] /= image_size[1]
coords[:, :, 1] /= image_size[0]
return self._pe_encoding(coords)
class PromptEncoder(_PromptEncoder):
"""ONNX-compatible prompt encoder for SAM model export.
This is a drop-in replacement for segment_anything_hq's PromptEncoder
that uses ONNX-traceable operations. Key differences:
- Replaces boolean indexing with element-wise multiplication
- Uses sentinel values (all-zero tensors) for optional inputs
- All operations are pure tensor ops (no .item(), no Python conditionals)
- Uses dtype-aware PositionEmbeddingRandom for bfloat16/float16 support
This encoder works for both PyTorch inference and ONNX/OpenVINO export.
See Also:
segment_anything_hq.modeling.prompt_encoder.PromptEncoder
"""
def __init__(
self,
embed_dim: int,
image_embedding_size: tuple[int, int],
input_image_size: tuple[int, int],
mask_in_chans: int,
activation: type[nn.Module] = nn.GELU,
) -> None:
"""Initialize with dtype-aware PositionEmbeddingRandom."""
super().__init__(embed_dim, image_embedding_size, input_image_size, mask_in_chans, activation)
# Replace pe_layer with dtype-aware version
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
points += 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device, dtype=points.dtype)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device, dtype=labels.dtype)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
# Use ONNX-compatible operations instead of boolean indexing
# Create masks for each label type
mask_neg1 = (labels == -1).float().unsqueeze(-1) # [B, N, 1]
mask_0 = (labels == 0).float().unsqueeze(-1) # [B, N, 1]
mask_1 = (labels == 1).float().unsqueeze(-1) # [B, N, 1]
# Apply embeddings using element-wise multiplication
point_embedding *= 1 - mask_neg1 # Zero out -1 labels
point_embedding += mask_neg1 * self.not_a_point_embed.weight
point_embedding += mask_0 * self.point_embeddings[0].weight
point_embedding += mask_1 * self.point_embeddings[1].weight
return point_embedding
def _get_dtype(self) -> torch.dtype:
return self.point_embeddings[0].weight.dtype
def forward(
self,
points: tuple[torch.Tensor, torch.Tensor] | None,
boxes: torch.Tensor | None,
masks: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""ONNX-traceable forward pass with optional box and mask prompts.
Uses sentinel values (all-zero tensors) to handle optional inputs:
- All-zero boxes → skip box embedding (zero out embeddings)
- All-zero masks → use default no_mask_embed (blend to default)
All operations are pure tensor operations (no .item(), no Python
conditionals on tensor values) to ensure ONNX traceability.
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device(), dtype=self._get_dtype())
if points is not None:
coords, labels = points
# Always pad points when boxes input exists (even if dummy)
# The box embeddings will be masked out later if they're dummy
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
# Embed boxes first
box_embeddings = self._embed_boxes(boxes)
# Detect dummy boxes: check if all coordinates are zero
# box shape is typically [B, 4] or [B, 1, 4]
boxes_flat = boxes.reshape(boxes.shape[0], -1) # [B, 4] or [B, num_boxes*4]
boxes_sum = boxes_flat.abs().sum(dim=1, keepdim=True) # [B, 1]
# Create mask: 1.0 if boxes are valid (non-zero), 0.0 if dummy (all zeros)
has_valid_boxes = (boxes_sum > 0).float() # [B, 1]
has_valid_boxes = has_valid_boxes.unsqueeze(-1) # [B, 1, 1]
has_valid_boxes = has_valid_boxes.expand(-1, box_embeddings.shape[1], -1) # [B, num_boxes, 1]
has_valid_boxes = has_valid_boxes.expand_as(box_embeddings) # [B, num_boxes, embed_dim]
# Zero out box embeddings for dummy boxes (element-wise multiplication)
box_embeddings *= has_valid_boxes
# Concatenate (zeros will be concatenated if boxes were dummy)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
# Compute mask embeddings
mask_embeddings = self._embed_masks(masks)
# Detect dummy masks: check if all values are zero
# mask shape is [B, 1, H, W]
masks_flat = masks.reshape(masks.shape[0], -1) # [B, H*W]
masks_sum = masks_flat.abs().sum(dim=1, keepdim=True) # [B, 1]
# Create mask: 1.0 if masks are valid (non-zero), 0.0 if dummy (all zeros)
has_valid_masks = (masks_sum > 0).to(self._get_dtype()) # [B, 1]
# Get default "no mask" embedding
no_mask_embed = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs,
-1,
self.image_embedding_size[0],
self.image_embedding_size[1],
)
# Blend between mask embeddings and no_mask embeddings
# has_valid_masks: [B, 1] -> [B, 1, 1, 1] -> [B, embed_dim, H, W]
has_valid_masks = has_valid_masks.view(bs, 1, 1, 1) # [B, 1, 1, 1]
has_valid_masks = has_valid_masks.expand_as(mask_embeddings) # [B, embed_dim, H, W]
# If masks are valid, use mask_embeddings; otherwise use no_mask_embed
dense_embeddings = has_valid_masks * mask_embeddings + (1 - has_valid_masks) * no_mask_embed
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs,
-1,
self.image_embedding_size[0],
self.image_embedding_size[1],
)
return sparse_embeddings, dense_embeddings
class SAMPredictor(nn.Module):
"""PyTorch implementation of SAM predictor.
This implementation wraps the original SAM predictor from segment_anything_hq
and SAM2 predictors, providing a unified interface while delegating to the
appropriate backend predictor.
The prompt encoder is patched with an ONNX-compatible version that uses
element-wise operations instead of boolean indexing, enabling both efficient
PyTorch inference and seamless export to ONNX/OpenVINO formats.
Note:
**Optional Prompts**: When using exported models, boxes and mask_input
support "not provided" scenarios using sentinel values:
- **Boxes**: Pass all-zero boxes (e.g., [[0, 0, 0, 0]]) to indicate "no boxes".
The prompt encoder detects these and zeros out box embeddings.
- **Mask Input**: Pass all-zero masks (e.g., zeros((B, 1, 256, 256))) to indicate
"no mask input". The prompt encoder detects these and uses the default no_mask_embed.
"""
# SAM-HQ registry name mapping (model_id -> segment_anything_hq registry key)
SAM_HQ_REGISTRY_MAP: ClassVar[dict[str, str]] = {
"sam-hq": "vit_h",
"sam-hq-tiny": "vit_tiny",
}
def __init__(
self,
model_id: str,
device: str,
model_path: Path | None = None,
target_length: int = 1024,
) -> None:
"""Initialize SAM predictor.
Args:
sam_model_name: The SAM model architecture (e.g., SAM_HQ_TINY, SAM2_BASE)
device: Device to run inference on ("cuda", "cpu")
model_path: Path to .pth checkpoint file (optional, auto-downloads if None)
target_length: Target length for the longest side of the image during transformation. Defaults to 1024.
Raises:
ValueError: If the model ID is not found in registry.
"""
super().__init__()
self.device = device
self._sam_model_name = sam_model_name
self.transform = ResizeLongestSide(target_length)
self._original_size: tuple[int, int] | None = None
# Determine checkpoint path
if model_path is None:
check_model_weights(model_id)
local_filename = get_local_filename(model_id)
checkpoint_path = DATA_PATH.joinpath(local_filename)
else:
checkpoint_path = model_path
msg = f"Loading PyTorch SAM: {model_id} from {checkpoint_path}"
logger.info(msg)
# Load model based on family
if model_meta.family == "SAM2":
config_path = "configs/sam2.1/" + model_meta.config_filename
sam_model = build_sam2(config_path, str(checkpoint_path))
self._predictor = SAM2ImagePredictor(sam_model)
elif model_meta.family == "SAM-HQ":
registry_name = self.SAM_HQ_REGISTRY_MAP.get(model_id)
if registry_name is None:
msg = f"SAM-HQ model '{model_id}' not in SAM_HQ_REGISTRY_MAP"
raise ValueError(msg)
sam_model = sam_model_registry[registry_name]().to(device)
# suppress - loading the snapshot from the local path
# nosemgrep trailofbits.python.pickles-in-pytorch.pickles-in-pytorch
state_dict = torch.load(checkpoint_path, map_location=device) # nosec: B614
info = sam_model.load_state_dict(state_dict, strict=False)
if info.missing_keys:
msg = f"Missing keys when loading SAM-HQ model: {info.missing_keys}"
logger.warning(msg)
if info.unexpected_keys:
msg = f"Unexpected keys when loading SAM-HQ model: {info.unexpected_keys}"
logger.warning(msg)
sam_model.eval()
self._predictor = _SamPredictor(sam_model)
# Patch with ONNX-compatible prompt encoder for SAM-HQ models
self._patch_prompt_encoder(device)
self._freeze_modules([
self._predictor.model.mask_decoder,
self._predictor.model.prompt_encoder,
self._predictor.model.image_encoder,
])
else:
msg = f"Model family '{model_meta.family}' not implemented"
raise NotImplementedError(msg)
def _patch_prompt_encoder(self, device: str) -> None:
"""Replace prompt encoder with ONNX-compatible version.
This patches the SAM-HQ model's prompt encoder to use element-wise
operations instead of boolean indexing, enabling ONNX/OpenVINO export.
The patched encoder also supports bfloat16/float16 precision.
"""
original_encoder = self._predictor.model.prompt_encoder
patched_encoder = PromptEncoder(
embed_dim=original_encoder.embed_dim,
image_embedding_size=original_encoder.image_embedding_size,
input_image_size=original_encoder.input_image_size,
mask_in_chans=16, # It's always 16
)
# Load weights from original encoder (preserves original dtype)
patched_encoder.load_state_dict(original_encoder.state_dict(), strict=True)
patched_encoder.to(device)
self._predictor.model.prompt_encoder = patched_encoder
def set_image(self, image: torch.Tensor | str | Path) -> None:
"""Set image using PyTorch backend.
Transforms the image to the target size and delegates to the underlying
predictor's set_torch_image method, which computes and caches image
embeddings for efficient inference.
Args:
image: Raw image tensor of shape (C, H, W), or path to image file.
"""
if isinstance(image, (str, Path)):
image = read_image(image)
self._original_size = image.shape[-2:]
transformed_image = self.transform.apply_image_torch(image).to(device=self.device)
return self._predictor.set_torch_image(transformed_image, self._original_size)
@staticmethod
def _freeze_modules(modules: list[nn.Module]) -> None:
"""Freeze the modules."""
for module in modules:
for p in module.parameters():
p.requires_grad_(requires_grad=False)
@torch.inference_mode()
def forward(
self,
point_coords: torch.Tensor | None,
point_labels: torch.Tensor | None,
boxes: torch.Tensor | None = None,
mask_input: torch.Tensor | None = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict masks using PyTorch backend.
Transforms point coordinates and boxes to the target image size,
then delegates to the underlying predictor's predict_torch method.
Args:
point_coords: Point coordinates [B, N, 2] in (x, y) format (original image coordinates)
point_labels: Point labels [B, N] (1=foreground, 0=background, -1=padding)
boxes: Box prompts [B, 4] or [B, 1, 4] in (x1, y1, x2, y2) format (original image coordinates)
mask_input: Low-res mask input [B, 1, 256, 256]
multimask_output: Whether to return multiple masks
return_logits: Whether to return logits instead of binary masks
Returns:
Tuple of (masks, iou_predictions, low_res_logits)
Raises:
RuntimeError: If set_image() was not called before predict()
"""
if self._original_size is None:
msg = "Must call set_image() before predict()"
raise RuntimeError(msg)
# Transform point coordinates to target image size
transformed_point_coords = None
if point_coords is not None:
# point_coords shape: [B, N, 2]
original_shape = point_coords.shape
coords_flat = point_coords.reshape(-1, 2)
transformed_coords = self.transform.apply_coords_torch(coords_flat, self._original_size)
transformed_point_coords = transformed_coords.reshape(original_shape)
# Transform boxes to target image size
transformed_boxes = None
if boxes is not None:
# boxes shape: [B, 4] or [B, 1, 4]
original_shape = boxes.shape
boxes_flat = boxes.reshape(-1, 4)
transformed_boxes_flat = self.transform.apply_boxes_torch(boxes_flat, self._original_size)
transformed_boxes = transformed_boxes_flat.reshape(original_shape)
return self._predictor.predict_torch(
point_coords=transformed_point_coords,
point_labels=point_labels,
boxes=transformed_boxes,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
)