|
| 1 | +# SAM-HQ |
| 2 | + |
| 3 | +## Overview |
| 4 | + |
| 5 | +SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu. |
| 6 | + |
| 7 | +The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. |
| 8 | + |
| 9 | + |
| 10 | + |
| 11 | + |
| 12 | +SAM-HQ introduces several key improvements over the original SAM model: |
| 13 | + |
| 14 | +1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction |
| 15 | +2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details |
| 16 | +3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B |
| 17 | +4. Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality |
| 18 | +5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy |
| 19 | + |
| 20 | +The abstract from the paper is the following: |
| 21 | + |
| 22 | +*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.* |
| 23 | + |
| 24 | +Tips: |
| 25 | + |
| 26 | +- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details |
| 27 | +- The model predicts binary masks with more accurate boundaries and better handling of thin structures |
| 28 | +- Like SAM, the model performs better with input 2D points and/or input bounding boxes |
| 29 | +- You can prompt multiple points for the same image and predict a single high-quality mask |
| 30 | +- The model maintains SAM's zero-shot generalization capabilities |
| 31 | +- SAM-HQ only adds ~0.5% additional parameters compared to SAM |
| 32 | +- Fine-tuning the model is not supported yet |
| 33 | + |
| 34 | +This model was contributed by [sushmanth](https://huggingface.co/sushmanth). |
| 35 | +The original code can be found [here](https://github.com/SysCV/SAM-HQ). |
| 36 | + |
| 37 | +Below is an example on how to run mask generation given an image and a 2D point: |
| 38 | + |
| 39 | +```python |
| 40 | +import torch |
| 41 | +from PIL import Image |
| 42 | +import requests |
| 43 | +from transformers import SamHQModel, SamHQProcessor |
| 44 | + |
| 45 | +device = "cuda" if torch.cuda.is_available() else "cpu" |
| 46 | +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) |
| 47 | +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") |
| 48 | + |
| 49 | +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" |
| 50 | +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") |
| 51 | +input_points = [[[450, 600]]] # 2D location of a window in the image |
| 52 | + |
| 53 | +inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device) |
| 54 | +with torch.no_grad(): |
| 55 | + outputs = model(**inputs) |
| 56 | + |
| 57 | +masks = processor.image_processor.post_process_masks( |
| 58 | + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() |
| 59 | +) |
| 60 | +scores = outputs.iou_scores |
| 61 | +``` |
| 62 | + |
| 63 | +You can also process your own masks alongside the input images in the processor to be passed to the model: |
| 64 | + |
| 65 | +```python |
| 66 | +import torch |
| 67 | +from PIL import Image |
| 68 | +import requests |
| 69 | +from transformers import SamHQModel, SamHQProcessor |
| 70 | + |
| 71 | +device = "cuda" if torch.cuda.is_available() else "cpu" |
| 72 | +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) |
| 73 | +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") |
| 74 | + |
| 75 | +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" |
| 76 | +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") |
| 77 | +mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" |
| 78 | +segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1") |
| 79 | +input_points = [[[450, 600]]] # 2D location of a window in the image |
| 80 | + |
| 81 | +inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device) |
| 82 | +with torch.no_grad(): |
| 83 | + outputs = model(**inputs) |
| 84 | + |
| 85 | +masks = processor.image_processor.post_process_masks( |
| 86 | + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() |
| 87 | +) |
| 88 | +scores = outputs.iou_scores |
| 89 | +``` |
| 90 | + |
| 91 | + |
| 92 | +## Resources |
| 93 | + |
| 94 | +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ: |
| 95 | + |
| 96 | +- Demo notebook for using the model (coming soon) |
| 97 | +- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ) |
| 98 | + |
| 99 | +## SamHQConfig |
| 100 | + |
| 101 | +[[autodoc]] SamHQConfig |
| 102 | + |
| 103 | +## SamHQVisionConfig |
| 104 | + |
| 105 | +[[autodoc]] SamHQVisionConfig |
| 106 | + |
| 107 | +## SamHQMaskDecoderConfig |
| 108 | + |
| 109 | +[[autodoc]] SamHQMaskDecoderConfig |
| 110 | + |
| 111 | +## SamHQPromptEncoderConfig |
| 112 | + |
| 113 | +[[autodoc]] SamHQPromptEncoderConfig |
| 114 | + |
| 115 | +## SamHQProcessor |
| 116 | + |
| 117 | +[[autodoc]] SamHQProcessor |
| 118 | + |
| 119 | +## SamHQVisionModel |
| 120 | + |
| 121 | +[[autodoc]] SamHQVisionModel |
| 122 | + |
| 123 | + |
| 124 | +## SamHQModel |
| 125 | + |
| 126 | +[[autodoc]] SamHQModel |
| 127 | + - forward |
0 commit comments