Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/functional-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: 🧪 RF-DETR Functional Tests

on:
pull_request:
branches: [main, develop]
push:
branches: [develop]

permissions:
contents: read
checks: write

jobs:
functional-tests:
name: Run functional test suite
runs-on: ubuntu-latest
timeout-minutes: 20
strategy:
matrix:
python-version: ["3.10"]

steps:
- name: 📥 Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0

- name: 🐍 Install uv and set Python ${{ matrix.python-version }}
uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

- name: 🏗️ Install dependencies
run: |
uv pip install .

- name: ✅ Run functional tests
run: |
uv run python functional_testing.py
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ On image segmentation, RF-DETR Seg (Preview) is 3x faster and more accurate than
- `2025/04/03`: We release early stopping, gradient checkpointing, metrics saving, training resume, TensorBoard and W&B logging support.
- `2025/03/20`: We release RF-DETR real-time object detection model. **Code and checkpoint for RF-DETR-large and RF-DETR-base are available.**

## Advanced Features (Experimental)

The following features are available when installing RF-DETR from source and enable more advanced research and customization use cases:

- **IoU-aware query selection and adaptive query allocation** (detection)
- Optional improvements to query initialization inside the transformer.
- Controlled via `ModelConfig` / CLI flags: `use_iou_aware_query`, `adaptive_query_allocation`.

- **Enhanced segmentation head with mask quality scoring** (segmentation)
- New head that adds mask quality prediction and dynamic refinement.
- Controlled via `ModelConfig` / CLI flags: `enhanced_segmentation`, `mask_quality_prediction`, `dynamic_mask_refinement`.

- **Advanced data augmentations** (training)
- Mosaic, MixUp, and Copy-Paste augmentations implemented in `rfdetr.datasets.advanced_augmentations`.
- Can be plugged into custom training scripts for stronger data augmentation pipelines.

## Results

RF-DETR achieves state-of-the-art performance on both the Microsoft COCO and the RF100-VL benchmarks.
Expand Down Expand Up @@ -180,6 +196,16 @@ You can fine-tune an RF-DETR Nano, Small, Medium, and Base model with a custom d

Visit our [documentation website](https://rfdetr.roboflow.com) to learn more about how to use RF-DETR.

### Testing and CI

For contributors, a small functional test suite is provided to validate core enhancements (IoU-aware queries, adaptive query allocation, and enhanced segmentation head):

```bash
python functional_testing.py
```

GitHub Actions runs these tests automatically via the `functional-tests` workflow on pull requests targeting the main development branches.

## License

Both the code and the weights pretrained on the COCO dataset are released under the [Apache 2.0 license](https://github.com/roboflow/r-flow/blob/main/LICENSE).
Expand Down
198 changes: 198 additions & 0 deletions functional_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""Test script for RF-DETR enhancements.

Validates IoU-aware query selection, adaptive query allocation,
and enhanced segmentation head integration.
"""

import torch
import numpy as np
from PIL import Image
import sys
import os

# Add rfdetr to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rfdetr.models.iou_aware_query_selector import IoUAwareQuerySelector, AdaptiveQueryAllocator
from rfdetr.models.enhanced_segmentation_head import EnhancedSegmentationHead, AdaptiveMaskLoss


def test_iou_aware_query_selector():
"""Test IoU-aware query selector."""
print("Testing IoU-aware query selector...")

# Create dummy data
batch_size = 2
num_queries = 300
feature_dim = 256
num_memory = 1000

# Initialize selector
selector = IoUAwareQuerySelector(
d_model=feature_dim,
num_queries=num_queries
)

# Create dummy inputs
memory = torch.randn(batch_size, num_memory, feature_dim)
spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8], [4, 4]])
level_start_index = torch.tensor([0, 1024, 1536, 1792])
reference_points = torch.rand(batch_size, num_queries, 4)

# Forward pass
try:
selected_features, scores = selector(memory, spatial_shapes, level_start_index, reference_points)
assert selected_features.shape == (batch_size, num_queries, feature_dim)
assert scores.shape == (batch_size, num_queries, 1)
print("✓ IoU-aware query selector test passed")
return True
except Exception as e:
print(f"✗ IoU-aware query selector test failed: {e}")
return False


def test_adaptive_query_allocator():
"""Test adaptive query allocator."""
print("Testing adaptive query allocator...")

# Create dummy data
batch_size = 2
num_queries = 300
feature_dim = 256
num_memory = 1000

# Initialize allocator
allocator = AdaptiveQueryAllocator(base_queries=num_queries)

# Create dummy input
memory = torch.randn(batch_size, num_memory, feature_dim)

# Forward pass
try:
allocated_queries = allocator(memory)
assert isinstance(allocated_queries, int)
assert 100 <= allocated_queries <= 600 # Should be within min/max range
print(f"✓ Adaptive query allocator test passed (allocated {allocated_queries} queries)")
return True
except Exception as e:
print(f"✗ Adaptive query allocator test failed: {e}")
return False


def test_enhanced_segmentation_head():
"""Test enhanced segmentation head."""
print("Testing enhanced segmentation head...")

# Create dummy data
batch_size = 2
num_queries = 100
feature_dim = 256
image_size = (512, 512)

# Initialize enhanced segmentation head
seg_head = EnhancedSegmentationHead(
feature_dim=feature_dim,
num_layers=3,
use_quality_prediction=True,
use_dynamic_refinement=True
)

# Create dummy inputs
spatial_features = torch.randn(batch_size, feature_dim, 64, 64)
query_features = [torch.randn(batch_size, num_queries, feature_dim) for _ in range(3)]
bbox_features = torch.rand(batch_size, num_queries, 4)

# Forward pass
try:
mask_logits, quality_scores = seg_head(
spatial_features, query_features, image_size, bbox_features
)

assert len(mask_logits) == 3 # Should have 3 layers
assert mask_logits[-1].shape == (batch_size, num_queries, 128, 128)
assert quality_scores is not None
assert quality_scores.shape == (batch_size, num_queries, 1)
print("✓ Enhanced segmentation head test passed")
return True
except Exception as e:
print(f"✗ Enhanced segmentation head test failed: {e}")
return False




def test_integration():
"""Test integration with existing RF-DETR components."""
print("Testing integration...")

try:
# Test imports work correctly
from rfdetr.models.transformer import Transformer
from rfdetr.config import ModelConfig

# Test configuration with new features
config = ModelConfig(
encoder="dinov2_windowed_small",
out_feature_indexes=[2, 5, 8, 11],
dec_layers=3,
projector_scale=["P3", "P4", "P5"],
hidden_dim=256,
patch_size=14,
num_windows=4,
sa_nheads=8,
ca_nheads=8,
dec_n_points=4,
resolution=640,
positional_encoding_size=10000,
use_iou_aware_query=True,
adaptive_query_allocation=True,
enhanced_segmentation=True,
mask_quality_prediction=True,
dynamic_mask_refinement=True
)

assert config.use_iou_aware_query == True
assert config.adaptive_query_allocation == True
assert config.enhanced_segmentation == True

print("✓ Integration test passed")
return True
except Exception as e:
print(f"✗ Integration test failed: {e}")
return False


def main():
"""Run all tests."""
print("Running RF-DETR Enhancement Tests\n")
print("=" * 50)

tests = [
test_iou_aware_query_selector,
test_adaptive_query_allocator,
test_enhanced_segmentation_head,
test_integration,
]

passed = 0
total = len(tests)

for test in tests:
if test():
passed += 1
print()

print("=" * 50)
print(f"Tests passed: {passed}/{total}")

if passed == total:
print("🎉 All tests passed! Ready for PR.")
return 0
else:
print("❌ Some tests failed. Please fix issues before creating PR.")
return 1


if __name__ == "__main__":
exit(main())
5 changes: 5 additions & 0 deletions rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class ModelConfig(BaseModel):
cls_loss_coef: float = 1.0
segmentation_head: bool = False
mask_downsample_ratio: int = 4
use_iou_aware_query: bool = False
adaptive_query_allocation: bool = False
enhanced_segmentation: bool = False
mask_quality_prediction: bool = True
dynamic_mask_refinement: bool = True


class RFDETRBaseConfig(ModelConfig):
Expand Down
Loading