
Commit 65e9402

SamHQ model addition (#35147)
* added the configuration for sam_hq
* added the modelling for sam_hq
* added the sam hq mask decoder with hq features
* added the code for the samhq
* added the code for the samhq
* added the code for the samhq
* Delete src/transformers/models/sam_hq/modelling_sam_hq.py
* added the code for the samhq
* added the code for the samhq
* added the changes for the modelling
* added the code for sam hq for image processing
* added code for the sam hq model
* added the required changes
* added the changes
* added the key mappings for the sam hq
* adding the working code of samhq
* added the required files
* adding the pt object
* added the push to hub account
* added the args for the sam mask decoder
* added the args for the sam hq vision config
* added some more documentation
* removed the unnecessary spaces
* all required changes
* removed the image processor
* added the required file
* added the changes for the checkcopies
* added the code for modular file
* added the changes for the __init__ file
* added the code for the interm embeds
* added the code for sam hq
* added the changes for modular file
* added the test file
* added the changes required
* added the changes required
* added the code for the
* added the cl errors
* added the changes
* added the required changes
* added some code
* added the code for removing the image processor
* added the test dimensions
* added the code for removing extra used variables
* added the code for modular file, hf_mlp for a better name
* removed abbreviation in core functionality
* removed abbreviation in core functionality
* .contiguous() method is often used to ensure that the tensor is stored in a contiguous block of memory
* added the code which is after make fixup
* added some tests for the intermediate embeddings test
* added the code for the torch support in sam hq
* added the code for the updated modular file
* added the changes for documentation as mentioned
* removed the heading
* added the changes for the code
* first mentioned issue resolved
* added the changes to processor code
* added the easy loading to init file
* added the changes to code
* added the code to changes
* added the code to work
* added the code for sam hq
* added the code for sam hq
* added the code for the point pad value
* added the small test for the image embeddings and intermediate embeddings
* added the code
* added the code
* added the code for the tests
* added the code
* added the code for the processor file
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code for tests and some checks
* added some code
* added the code
* added the code
* added some code
* added some code
* added the required changes
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added the code
* added some changes
* added some changes
* removed spaces and quality checks
* added some code
* added some code
* added some code
* added code quality checks
* added the checks for quality checks
* added some code which fixes test_inference_mask_generation_no_point
* added code for the test_inference_mask_generation_one_point_one_bb
* added code for the test_inference_mask_generation_one_point_one_bb_zero
* added code for the test_inference_mask_generation_one_box
* added some code in modelling for testing
* added some code which sorts masks with high score
* added some code
* added some code
* added some code for the move KEYS_TO_MODIFY_MAPPING
* added some code for the unsqueeze removal
* added some code for the unsqueeze removal
* added some code
* added some code
* added some code
* added some code
* added some code
* added some testing values changed
* added changes to code in sam hq for readability purpose
* added pre-commit checks
* added the fix samvisionmodel for compatibility
* added the changes made on sam by cyyever
* fixed the tests for samhq
* added some code
* added some code related to init file issue during merge conflicts
* removed the merge conflicts
* added changes mentioned by Arthur and molbap
* added changes mentioned by Arthur and molbap
* solving quality checks
* added the changes for input clearly
* added the changes
* added changes in mask generation file regarding model inputs and sam hq kwargs in processor file
* added changes in processor file
* added the setUp -> setUpClass conversion
* added the code mentioned for processor
* added changes for the code
* added some code
* added some code
* added some code

---------

Co-authored-by: Pablo Montalvo <[email protected]>
1 parent 9c5b131 commit 65e9402

22 files changed: +4926 −1 lines changed

docs/source/en/_toctree.yml

+2
@@ -1017,6 +1017,8 @@
     title: Qwen2VL
   - local: model_doc/sam
     title: Segment Anything
+  - local: model_doc/sam_hq
+    title: Segment Anything High Quality
   - local: model_doc/shieldgemma2
     title: ShieldGemma2
   - local: model_doc/siglip

docs/source/en/model_doc/sam_hq.md

+127
@@ -0,0 +1,127 @@
# SAM-HQ

## Overview

SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu.

The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability.

![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png)
SAM-HQ introduces several key improvements over the original SAM model:

1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction
2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details (a rough sketch of the idea follows this list)
3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B
4. Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality
5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy
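The global-local fusion idea can be pictured with a small, self-contained sketch. This is an illustration only, not the actual `SamHQMaskDecoder` code from this commit; the tensor names (`early_vit_features`, `final_vit_features`, `decoder_features`), shapes, and 1x1-conv projections are assumptions chosen for readability:

```python
import torch
import torch.nn as nn

# Toy shapes: batch of 1, 256-dim features on a 64x64 grid (illustrative only).
batch, dim, size = 1, 256, 64
early_vit_features = torch.randn(batch, dim, size, size)  # early ViT stage (fine local detail)
final_vit_features = torch.randn(batch, dim, size, size)  # final ViT stage (global context)
decoder_features = torch.randn(batch, dim, size, size)    # SAM mask-decoder features

# Global-local fusion: project each source and sum them, so the high-quality mask head
# sees fine-grained early features as well as high-level ones.
fuse_early = nn.Conv2d(dim, dim, kernel_size=1)
fuse_final = nn.Conv2d(dim, dim, kernel_size=1)
fuse_decoder = nn.Conv2d(dim, dim, kernel_size=1)

hq_features = fuse_early(early_vit_features) + fuse_final(final_vit_features) + fuse_decoder(decoder_features)
print(hq_features.shape)  # torch.Size([1, 256, 64, 64])
```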
The abstract from the paper is the following:

*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.*
Tips:

- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details
- The model predicts binary masks with more accurate boundaries and better handling of thin structures
- Like SAM, the model performs better with input 2D points and/or input bounding boxes
- You can prompt multiple points for the same image and predict a single high-quality mask (a short example follows these tips)
- The model maintains SAM's zero-shot generalization capabilities
- SAM-HQ only adds ~0.5% additional parameters compared to SAM
- Fine-tuning the model is not supported yet
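As referenced in the tips above, several points can be passed as prompts for one object. A minimal, self-contained sketch, assuming the same checkpoint and test image as the examples below; the second point coordinate is an arbitrary illustrative choice:

```python
import torch
import requests
from PIL import Image
from transformers import SamHQModel, SamHQProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# Two clicks prompting the same object in the same image.
input_points = [[[450, 600], [500, 650]]]

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```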
This model was contributed by [sushmanth](https://huggingface.co/sushmanth).
The original code can be found [here](https://github.com/SysCV/SAM-HQ).
Below is an example of how to run mask generation given an image and a 2D point:

```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]  # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
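If you only need the single best mask, the candidates can be ranked by their predicted IoU scores. A small follow-on sketch for the snippet above, assuming the usual SAM-style output layout of `(batch_size, point_batch_size, num_candidates, ...)`:

```python
# `scores` has shape (batch_size, point_batch_size, num_candidates);
# `masks[0]` holds the candidate masks for the first image with matching leading dimensions.
best_idx = scores[0, 0].argmax().item()   # index of the highest-scoring candidate
best_mask = masks[0][0, best_idx]         # (height, width) mask for that candidate
print(f"best candidate: {best_idx}, IoU score: {scores[0, 0, best_idx].item():.3f}")
```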
You can also process your own masks alongside the input images in the processor to be passed to the model:

```python
import torch
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1")
input_points = [[[450, 600]]]  # 2D location of a window in the image

inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ:

- Demo notebook for using the model (coming soon)
- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ)

## SamHQConfig

[[autodoc]] SamHQConfig

## SamHQVisionConfig

[[autodoc]] SamHQVisionConfig

## SamHQMaskDecoderConfig

[[autodoc]] SamHQMaskDecoderConfig

## SamHQPromptEncoderConfig

[[autodoc]] SamHQPromptEncoderConfig

## SamHQProcessor

[[autodoc]] SamHQProcessor

## SamHQVisionModel

[[autodoc]] SamHQVisionModel

## SamHQModel

[[autodoc]] SamHQModel
    - forward

src/transformers/models/__init__.py

+1
@@ -254,6 +254,7 @@
 from .rt_detr_v2 import *
 from .rwkv import *
 from .sam import *
+from .sam_hq import *
 from .seamless_m4t import *
 from .seamless_m4t_v2 import *
 from .segformer import *

src/transformers/models/auto/configuration_auto.py

+5
@@ -286,6 +286,8 @@
         ("rt_detr_v2", "RTDetrV2Config"),
         ("rwkv", "RwkvConfig"),
         ("sam", "SamConfig"),
+        ("sam_hq", "SamHQConfig"),
+        ("sam_hq_vision_model", "SamHQVisionConfig"),
         ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
@@ -658,6 +660,8 @@
         ("rt_detr_v2", "RT-DETRv2"),
         ("rwkv", "RWKV"),
         ("sam", "SAM"),
+        ("sam_hq", "SAM-HQ"),
+        ("sam_hq_vision_model", "SamHQVisionModel"),
         ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
@@ -807,6 +811,7 @@
     ("qwen2_5_vl_text", "qwen2_5_vl"),
     ("qwen2_vl_text", "qwen2_vl"),
     ("sam_vision_model", "sam"),
+    ("sam_hq_vision_model", "sam_hq"),
     ("llama4_text", "llama4"),
     ("blip_2_qformer", "blip_2"),
 ]

src/transformers/models/auto/image_processing_auto.py

+1
@@ -141,6 +141,7 @@
         ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
         ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
         ("sam", ("SamImageProcessor",)),
+        ("sam_hq", ("SamImageProcessor",)),
         ("segformer", ("SegformerImageProcessor",)),
         ("seggpt", ("SegGptImageProcessor",)),
         ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),

src/transformers/models/auto/modeling_auto.py

+8
@@ -257,6 +257,8 @@
         ("rt_detr_v2", "RTDetrV2Model"),
         ("rwkv", "RwkvModel"),
         ("sam", "SamModel"),
+        ("sam_hq", "SamHQModel"),
+        ("sam_hq_vision_model", "SamHQVisionModel"),
         ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
@@ -1495,6 +1497,12 @@
     ]
 )
 
+MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
+    [
+        ("sam_hq", "SamHQModel"),
+    ]
+)
+
 
 MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
     [
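The mask-generation mapping entry above is what lets SAM-HQ checkpoints be resolved by the `mask-generation` pipeline. A minimal sketch, not part of this diff, assuming the `sushmanth/sam_hq_vit_b` checkpoint from the docs above and default pipeline settings:

```python
from transformers import pipeline

# Build the automatic mask-generation pipeline; the auto-mapping resolves SamHQModel from the config type.
generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b")

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=64)  # automatic mask generation over a grid of points

print(len(outputs["masks"]), "masks generated")
print(outputs["scores"][:3])  # IoU scores of the first few masks
```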

src/transformers/models/auto/processing_auto.py

+1
@@ -104,6 +104,7 @@
         ("qwen2_audio", "Qwen2AudioProcessor"),
         ("qwen2_vl", "Qwen2VLProcessor"),
         ("sam", "SamProcessor"),
+        ("sam_hq", "SamHQProcessor"),
         ("seamless_m4t", "SeamlessM4TProcessor"),
         ("sew", "Wav2Vec2Processor"),
         ("sew-d", "Wav2Vec2Processor"),
src/transformers/models/sam_hq/__init__.py

+28
@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_sam_hq import *
    from .modeling_sam_hq import *
    from .processing_samhq import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
