Commit 8600ee5

feat: Transformers API
1 parent c4847ad commit 8600ee5

27 files changed

Lines changed: 7547 additions & 147 deletions

MODEL_CARD.md

Lines changed: 424 additions & 0 deletions
Large diffs are not rendered by default.

Makefile

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ lint:
 
 .PHONY: typecheck
 typecheck:
-	poetry run mypy .
+	uv run ty check
 
 .PHONY: test
 test:

README.md

Lines changed: 104 additions & 80 deletions
@@ -5,7 +5,9 @@
 [![Deploy](https://github.com/creative-graphic-design/MVANet/actions/workflows/deploy.yaml/badge.svg)](https://github.com/creative-graphic-design/MVANet/actions/workflows/deploy.yaml)
 [![PyPI](https://img.shields.io/pypi/v/mvanet.svg)](https://pypi.python.org/pypi/mvanet)
 
-This is a fork of the original [MVANet](https://github.com/qianyu-dlut/MVANet), with bug fixes and packaging improvements.
+This is a fork of the original [MVANet](https://github.com/qianyu-dlut/MVANet), with bug fixes, packaging improvements, and a transformers-compatible API.
+
+MVANet is a Multi-view Aggregation Network for Dichotomous Image Segmentation, presented at CVPR 2024 (Highlight). It achieves state-of-the-art performance for high-precision object segmentation from high-resolution images.
 
 ## Installation
 
@@ -15,9 +17,52 @@ pip install mvanet
 
 ## Usage
 
+### Transformers API (Recommended)
+
+The transformers-compatible API provides better integration with the HuggingFace ecosystem:
+
+```python
+from PIL import Image
+from mvanet.transformers import (
+    MVANetConfig,
+    MVANetForImageSegmentation,
+    MVANetImageProcessor,
+)
+
+# Load image
+image = Image.open("/path/to/image.png")
+
+# Initialize model and processor
+config = MVANetConfig()
+model = MVANetForImageSegmentation(config)
+processor = MVANetImageProcessor()
+
+# Preprocess
+inputs = processor(image, return_tensors="pt")
+
+# Inference
+outputs = model(**inputs)
+
+# Post-process
+masks = processor.post_process_semantic_segmentation(
+    outputs, target_sizes=[image.size[::-1]]
+)
+```
+
+### Legacy API (Optional)
+
+The original predictor API is available as an optional dependency:
+
+```shell
+# Install with legacy API support
+pip install mvanet[original]
+# or with uv
+uv sync --group original
+```
+
 ```python
 from PIL import Image
-from mvanet.predictor import MVANetPredictor
+from mvanet import MVANetPredictor
 
 test_image = Image.open("/path/to/image.png")
 
@@ -32,94 +77,73 @@ predicted_mask = predictor(test_image, output_type="map")
 predicted_mask.save("mask.png")
 ```
 
----
-
-The official repo of the CVPR 2024 paper (Highlight), [Multi-view Aggregation Network for Dichotomous Image Segmentation](https://arxiv.org/abs/2404.07445)
-
-
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te1)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te1?p=multi-view-aggregation-network-for)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te2)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te2?p=multi-view-aggregation-network-for)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te3)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te3?p=multi-view-aggregation-network-for)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-te4)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-te4?p=multi-view-aggregation-network-for)
-[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-view-aggregation-network-for/dichotomous-image-segmentation-on-dis-vd)](https://paperswithcode.com/sota/dichotomous-image-segmentation-on-dis-vd?p=multi-view-aggregation-network-for)
-## Introduction
-Dichotomous Image Segmentation (DIS) has recently emerged towards high-precision object segmentation from high-resolution natural images. When designing an effective DIS model, the main challenge is how to balance the semantic dispersion of high-resolution targets in the small receptive field and the loss of high-precision details in the large receptive field. Existing methods rely on tedious multiple encoder-decoder streams and stages to gradually complete the global localization and local refinement.
-
-Human visual system captures regions of interest by observing them from multiple views. Inspired by it, we model DIS as a multi-view object perception problem and provide a parsimonious multi-view aggregation network (MVANet), which unifies the feature fusion of the distant view and close-up view into a single stream with one encoder-decoder structure. Specifically, we split the high-resolution input images from the original view into the distant view images with global information and close-up view images with local details. Thus, they can constitute a set of complementary multi-view low-resolution input patches.
-<p align="center">
-<img src="https://github.com/qianyu-dlut/MVANet/assets/73575386/2cff2cc2-ca24-469b-98ab-ed2585329609" alt="image" width="900"/>
-</p>
-
-Moreover, two efficient transformer-based multi-view complementary localization and refinement modules (MCLM & MCRM) are proposed to jointly capturing the localization and restoring the boundary details of the targets.
-<p align="center">
-<img src="https://github.com/qianyu-dlut/MVANet/assets/73575386/14c3e234-bdfe-49a5-a5ed-c82cc776d947" alt="image" width="900"/>
-</p>
-
-
-We achieves state-of-the-art performance in terms of almost all metrics on the DIS benchmark dataset.
-<p align="center">
-<img src="https://github.com/qianyu-dlut/MVANet/assets/73575386/6f3c0c1b-6cc2-4f0d-b563-7dc0c9050a52" alt="image" width="900"/>
-</p>
-
-We have optimized the code and achieved an enhanced FPS performance, reaching 15.2.
-<p align="center">
-<img src="https://github.com/qianyu-dlut/MVANet/assets/73575386/4de86a52-5b55-4095-9a1f-afda40ce7f7a" alt="image" width="500"/>
-</p>
-
-Here are some of our visual results:
-<p align="center">
-<img src="https://github.com/qianyu-dlut/MVANet/assets/73575386/3c4443d8-fd6f-49f3-988d-45215bc1d8e6" alt="image" width="900"/>
-</p>
-
-
-## I. Requiremets
-+ python==3.7
-+ torch==1.10.0
-+ torchvision==0.11.0
-+ mmcv-full==1.3.17
-+ mmdet==2.17.0
-+ mmengine==0.8.1
-+ mmsegmentation==0.19.0
-+ numpy
-+ ttach
-+ einops
-+ timm
-+ scipy
-
-## II. Training
-1. Download the pretrained model at [Google Drive](https://drive.google.com/file/d/1-Zi_DtCT8oC2UAZpB3_XoFOIxIweIAyk/view?usp=sharing).
-2. Then, you can start training by simply run:
-```
-python train.py
-```
+### Test Time Augmentation (TTA)
 
-## III. Testing
-1. Update the data path in config file `./utils/config.py` (line 4~8)
-2. Replace the existing path with the path to your saved model in `./predict.py` (line 14)
+For higher quality predictions, you can use TTA:
 
-You can also download our trained model at [Google Drive](https://drive.google.com/file/d/1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv/view?usp=sharing).
-3. Start predicting by:
-```
-python predict.py
-```
-4. Change the predicted map path in `./test.py` (line 17) and start testing:
-```
-python test.py
+```python
+from mvanet.transformers import MVANetTTAPipeline
+
+# Create TTA pipeline
+tta_pipeline = MVANetTTAPipeline(model, processor)
+
+# Run inference with TTA
+masks = tta_pipeline([image])
+mask_pil = masks[0]
+mask_pil.save("mask_tta.png")
 ```
 
-You can get our prediction maps at [Google Drive](https://drive.google.com/file/d/1z21OMJ0Zl7JYKFxqR3P2YJTT3zay8doq/view?usp=sharing).
-## To Do List
-- Release our camere-ready paper on arxiv (done)
-- Release our training code (done)
-- Release our model checkpoints (done)
-- Release our prediction maps (done)
+## Configuration
 
-## Citations
+The model behavior can be customized through `MVANetConfig`:
+
+```python
+from mvanet.transformers import MVANetConfig
+
+config = MVANetConfig(
+    embedding_dim=128,  # Embedding dimension throughout the model
+    backbone_out_channels=[128, 128, 256, 512, 1024],  # Backbone output channels (SwinB)
+    mclm_pool_ratios=[1, 4, 8],  # MCLM multi-scale attention ratios
+    mcrm_pool_ratios=[2, 4, 8],  # MCRM multi-scale attention ratios
+    insmask_hidden_dim=384,  # Instance mask head hidden dimension
+    global_view_scale=0.5,  # Global view downscale factor
+    num_patches=4,  # Number of local patches (2x2 grid)
+    image_size=1024,  # Input image size
+    num_channels=3,  # Number of input channels (RGB)
+    num_labels=1,  # Number of output labels (binary segmentation)
+)
 ```
+
+### Key Parameters
+
+- **`mcrm_pool_ratios`**: Controls the pooling ratios in the Multi-view Complementary Refinement Module (MCRM). Default `[2, 4, 8]` matches the trained model.
+- **`global_view_scale`**: Scale factor for creating the global view (downsampled version). Default `0.5` creates a half-resolution global view.
+- **`num_patches`**: Number of local patches. Currently only `4` (2x2 grid) is supported.
+- **`insmask_hidden_dim`**: Hidden dimension in the instance mask head. Larger values may capture more complex patterns but require more memory.
+
+## Paper and Citation
+
+This implementation is based on the CVPR 2024 paper:
+
+**Multi-view Aggregation Network for Dichotomous Image Segmentation**
+Qian Yu, Xiaoqi Zhao, Youwei Pang, Lihe Zhang, Huchuan Lu
+[arXiv:2404.07445](https://arxiv.org/abs/2404.07445)
+
+```bibtex
 @article{yu2024multi,
   title={Multi-view Aggregation Network for Dichotomous Image Segmentation},
   author={Yu, Qian and Zhao, Xiaoqi and Pang, Youwei and Zhang, Lihe and Lu, Huchuan},
   journal={arXiv preprint arXiv:2404.07445},
   year={2024}
 }
 ```
+
+## Links
+
+- **Original Repository**: [qianyu-dlut/MVANet](https://github.com/qianyu-dlut/MVANet)
+- **Paper**: [arXiv:2404.07445](https://arxiv.org/abs/2404.07445)
+- **Checkpoints**: [Google Drive](https://drive.google.com/file/d/1_gabQXOF03MfXnf3EWDK1d_8wKiOemOv/view?usp=sharing)
+
+## License
+
+This project follows the license of the original MVANet repository.
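
The configuration comments in the README diff describe a 2x2 grid of local patches (`num_patches=4`) and a half-resolution global view (`global_view_scale=0.5`) for a 1024-pixel input. The following is a minimal sketch of that geometry only, not the package's implementation: the function name `multi_view_layout`, the row-major patch order, and the square-input assumption are illustrative choices that do not appear in the commit.

```python
from typing import List, Tuple

# (left, top, right, bottom), the box convention used by Pillow's Image.crop
Box = Tuple[int, int, int, int]


def multi_view_layout(
    image_size: int = 1024,
    num_patches: int = 4,
    global_view_scale: float = 0.5,
) -> Tuple[Tuple[int, int], List[Box]]:
    """Return the global-view size and the crop boxes of the local patches.

    Hypothetical helper: assumes a square input and a square patch grid.
    """
    grid = int(round(num_patches ** 0.5))
    if grid * grid != num_patches:
        raise ValueError("num_patches must be a perfect square")
    if image_size % grid != 0:
        raise ValueError("image_size must be divisible by the grid side")

    side = image_size // grid
    # Row-major order (top-left, top-right, bottom-left, bottom-right) is an assumption.
    boxes = [
        (col * side, row * side, (col + 1) * side, (row + 1) * side)
        for row in range(grid)
        for col in range(grid)
    ]
    global_side = int(image_size * global_view_scale)
    return (global_side, global_side), boxes


global_size, patch_boxes = multi_view_layout()
print(global_size)  # (512, 512)
for box in patch_boxes:
    print(box)
```

With the defaults, the global view and each local patch both come out at 512x512, which is consistent with the removed introduction's description of the views as complementary low-resolution inputs.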

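The README diff shows how to call `MVANetTTAPipeline` but not which augmentations it applies. As a generic illustration of test-time augmentation, and not a description of that pipeline, the sketch below averages a prediction with the prediction of the horizontally flipped input. `toy_predict` is a hypothetical stand-in model whose error depends on position, and masks are plain nested lists to keep the example free of dependencies.

```python
from typing import Callable, List

Mask = List[List[float]]  # rows of per-pixel foreground scores


def hflip(grid: Mask) -> Mask:
    """Mirror a 2D grid left to right."""
    return [row[::-1] for row in grid]


def predict_with_hflip_tta(predict: Callable[[Mask], Mask], image: Mask) -> Mask:
    """Average the direct prediction with the un-flipped prediction of the flipped input."""
    direct = predict(image)
    flipped_back = hflip(predict(hflip(image)))
    return [
        [(a + b) / 2 for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(direct, flipped_back)
    ]


def toy_predict(image: Mask) -> Mask:
    """Hypothetical stand-in model: copies the input but always misses the last column."""
    return [row[:-1] + [0.0] for row in image]


image = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(toy_predict(image))                          # [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
print(predict_with_hflip_tta(toy_predict, image))  # [[0.5, 1.0, 0.5], [0.5, 1.0, 0.5]]
```

The single-pass prediction loses the right-hand column entirely, while the averaged result spreads that position-dependent error across both edges. Real TTA setups often combine more transforms (for example vertical flips or rescaling) in the same predict, invert, average pattern.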