Skip to content

Commit 2b8cba0

Browse files
authored
Merge branch 'size-release' into docs/wip
2 parents 6bb140a + d6764d8 commit 2b8cba0

File tree

13 files changed

+317
-83
lines changed

13 files changed

+317
-83
lines changed

README.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ We validated the performance of RF-DETR on both Microsoft COCO and the RF100-VL
3838
| LW-DETR-M | 28.2 | 52.5 | 84.0 | 57.5 | 6.0 |
3939
| YOLO11m | 20.0 | 51.5 | 84.9 | 59.7 | <ins>5.7</ins> |
4040
| YOLOv8m | 28.9 | 50.6 | 85.0 | 59.8 | 6.3 |
41-
| RF-DETR-B | 29.0 | 53.3 | <ins>86.7</ins> | <ins>60.3</ins> | 6.0 |
41+
| RF-DETR-Medium | 33.7 | 54.8 | <ins>86.6</ins> | <ins>60.6</ins> | <ins>4.31</ins> |
4242

4343

4444
<details>
@@ -55,6 +55,7 @@ We validated the performance of RF-DETR on both Microsoft COCO and the RF100-VL
5555
- `2025/03/20`: We release the RF-DETR real-time object detection model. **Code and checkpoints for RF-DETR-large and RF-DETR-base are available.**
5656
- `2025/04/03`: We release early stopping, gradient checkpointing, metrics saving, training resume, TensorBoard and W&B logging support.
5757
- `2025/05/16`: We release an 'optimize_for_inference' method which speeds up native PyTorch by up to 2x, depending on platform.
58+
- `2025/07/23`: We release new SOTA model sizes: RF-DETR-Nano, RF-DETR-Small, RF-DETR-Medium.
5859

5960
## Installation
6061

@@ -79,7 +80,22 @@ pip install git+https://github.com/roboflow/rf-detr.git
7980

8081
## Inference
8182

82-
The easiest path to deployment is using Roboflow's [Inference](https://github.com/roboflow/inference) package. You can use model's uploaded to Roboflow's platform with Inference's `infer` method:
83+
The easiest path to deployment is using Roboflow's [Inference](https://github.com/roboflow/inference) package.
84+
85+
You can upload models using `.deploy_to_roboflow` like so:
86+
87+
```python
88+
from rfdetr import RFDETRNano
89+
90+
x = RFDETRNano(pretrain_weights="<path/to/pretrain/weights/dir>")
91+
x.deploy_to_roboflow(
92+
workspace="<your-workspace>",
93+
project_ids=["<your-project-id>"],
94+
api_key="<YOUR_API_KEY>"
95+
)
96+
```
97+
98+
These models will be available to use with Inference's `infer` method:
8399

84100
```python
85101
import os

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ dependencies = [
5858
"pydantic",
5959
"supervision",
6060
"matplotlib",
61+
"roboflow"
6162
]
6263

6364
[project.optional-dependencies]

rfdetr/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") is None:
1010
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
1111

12-
from rfdetr.detr import RFDETRBase, RFDETRLarge
12+
from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRNano, RFDETRSmall, RFDETRMedium

rfdetr/config.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
class ModelConfig(BaseModel):
1414
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
1515
out_feature_indexes: List[int]
16-
dec_layers: int = 3
16+
dec_layers: int
1717
two_stage: bool = True
1818
projector_scale: List[Literal["P3", "P4", "P5"]]
1919
hidden_dim: int
20+
patch_size: int
21+
num_windows: int
2022
sa_nheads: int
2123
ca_nheads: int
2224
dec_n_points: int
@@ -27,13 +29,17 @@ class ModelConfig(BaseModel):
2729
num_classes: int = 90
2830
pretrain_weights: Optional[str] = None
2931
device: Literal["cpu", "cuda", "mps"] = DEVICE
30-
resolution: int = 560
32+
resolution: int
3133
group_detr: int = 13
3234
gradient_checkpointing: bool = False
35+
positional_encoding_size: int
3336

3437
class RFDETRBaseConfig(ModelConfig):
3538
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
3639
hidden_dim: int = 256
40+
patch_size: int = 14
41+
num_windows: int = 4
42+
dec_layers: int = 3
3743
sa_nheads: int = 8
3844
ca_nheads: int = 16
3945
dec_n_points: int = 2
@@ -42,6 +48,8 @@ class RFDETRBaseConfig(ModelConfig):
4248
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
4349
out_feature_indexes: List[int] = [2, 5, 8, 11]
4450
pretrain_weights: Optional[str] = "rf-detr-base.pth"
51+
resolution: int = 560
52+
positional_encoding_size: int = 37
4553

4654
class RFDETRLargeConfig(RFDETRBaseConfig):
4755
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
@@ -52,6 +60,33 @@ class RFDETRLargeConfig(RFDETRBaseConfig):
5260
projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
5361
pretrain_weights: Optional[str] = "rf-detr-large.pth"
5462

63+
class RFDETRNanoConfig(RFDETRBaseConfig):
64+
out_feature_indexes: List[int] = [3, 6, 9, 12]
65+
num_windows: int = 2
66+
dec_layers: int = 2
67+
patch_size: int = 16
68+
resolution: int = 384
69+
positional_encoding_size: int = 24
70+
pretrain_weights: Optional[str] = "rf-detr-nano.pth"
71+
72+
class RFDETRSmallConfig(RFDETRBaseConfig):
73+
out_feature_indexes: List[int] = [3, 6, 9, 12]
74+
num_windows: int = 2
75+
dec_layers: int = 3
76+
patch_size: int = 16
77+
resolution: int = 512
78+
positional_encoding_size: int = 32
79+
pretrain_weights: Optional[str] = "rf-detr-small.pth"
80+
81+
class RFDETRMediumConfig(RFDETRBaseConfig):
82+
out_feature_indexes: List[int] = [3, 6, 9, 12]
83+
num_windows: int = 2
84+
dec_layers: int = 4
85+
patch_size: int = 16
86+
resolution: int = 576
87+
positional_encoding_size: int = 36
88+
pretrain_weights: Optional[str] = "rf-detr-medium.pth"
89+
5590
class TrainConfig(BaseModel):
5691
lr: float = 1e-4
5792
lr_encoder: float = 1.5e-4
@@ -76,6 +111,7 @@ class TrainConfig(BaseModel):
76111
output_dir: str = "output"
77112
multi_scale: bool = True
78113
expanded_scales: bool = True
114+
do_random_resize_via_padding: bool = False
79115
use_ema: bool = True
80116
num_workers: int = 2
81117
weight_decay: float = 1e-4

rfdetr/datasets/coco.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,13 @@
2727
import rfdetr.datasets.transforms as T
2828

2929

30-
def compute_multi_scale_scales(resolution, expanded_scales=False):
31-
if resolution == 640:
32-
# assume we're doing the original 640x640 and therefore patch_size is 16
33-
patch_size = 16
34-
elif resolution % (14 * 4) == 0:
35-
# assume we're doing some dinov2 resolution variant and therefore patch_size is 14
36-
patch_size = 14
37-
elif resolution % (16 * 4) == 0:
38-
# assume we're doing some other resolution and therefore patch_size is 16
39-
patch_size = 16
40-
else:
41-
raise ValueError(f"Resolution {resolution} is not divisible by 16*4 or 14*4")
30+
def compute_multi_scale_scales(resolution, expanded_scales=False, patch_size=16, num_windows=4):
4231
# round to the nearest multiple of 4*patch_size to enable both patching and windowing
43-
base_num_patches_per_window = resolution // (patch_size * 4)
32+
base_num_patches_per_window = resolution // (patch_size * num_windows)
4433
offsets = [-3, -2, -1, 0, 1, 2, 3, 4] if not expanded_scales else [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
4534
scales = [base_num_patches_per_window + offset for offset in offsets]
46-
proposed_scales = [scale * patch_size * 4 for scale in scales]
47-
proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * 4] # ensure minimum image size
35+
proposed_scales = [scale * patch_size * num_windows for scale in scales]
36+
proposed_scales = [scale for scale in proposed_scales if scale >= patch_size * num_windows * 2] # ensure minimum image size
4837
return proposed_scales
4938

5039

@@ -107,7 +96,7 @@ def __call__(self, image, target):
10796
return image, target
10897

10998

110-
def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False):
99+
def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4):
111100

112101
normalize = T.Compose([
113102
T.ToTensor(),
@@ -117,7 +106,9 @@ def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scal
117106
scales = [resolution]
118107
if multi_scale:
119108
# scales = [448, 512, 576, 640, 704, 768, 832, 896]
120-
scales = compute_multi_scale_scales(resolution, expanded_scales)
109+
scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows)
110+
if skip_random_resize:
111+
scales = [scales[-1]]
121112
print(scales)
122113

123114
if image_set == 'train':
@@ -148,7 +139,7 @@ def make_coco_transforms(image_set, resolution, multi_scale=False, expanded_scal
148139
raise ValueError(f'unknown {image_set}')
149140

150141

151-
def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False):
142+
def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False, expanded_scales=False, skip_random_resize=False, patch_size=16, num_windows=4):
152143
"""
153144
"""
154145

@@ -161,7 +152,9 @@ def make_coco_transforms_square_div_64(image_set, resolution, multi_scale=False,
161152
scales = [resolution]
162153
if multi_scale:
163154
# scales = [448, 512, 576, 640, 704, 768, 832, 896]
164-
scales = compute_multi_scale_scales(resolution, expanded_scales)
155+
scales = compute_multi_scale_scales(resolution, expanded_scales, patch_size, num_windows)
156+
if skip_random_resize:
157+
scales = [scales[-1]]
165158
print(scales)
166159

167160
if image_set == 'train':
@@ -220,9 +213,25 @@ def build(image_set, args, resolution):
220213

221214

222215
if square_resize_div_64:
223-
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
216+
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(
217+
image_set,
218+
resolution,
219+
multi_scale=args.multi_scale,
220+
expanded_scales=args.expanded_scales,
221+
skip_random_resize=not args.do_random_resize_via_padding,
222+
patch_size=args.patch_size,
223+
num_windows=args.num_windows
224+
))
224225
else:
225-
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales))
226+
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(
227+
image_set,
228+
resolution,
229+
multi_scale=args.multi_scale,
230+
expanded_scales=args.expanded_scales,
231+
skip_random_resize=not args.do_random_resize_via_padding,
232+
patch_size=args.patch_size,
233+
num_windows=args.num_windows
234+
))
226235
return dataset
227236

228237
def build_roboflow(image_set, args, resolution):
@@ -249,7 +258,23 @@ def build_roboflow(image_set, args, resolution):
249258

250259

251260
if square_resize_div_64:
252-
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(image_set, resolution, multi_scale=args.multi_scale))
261+
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms_square_div_64(
262+
image_set,
263+
resolution,
264+
multi_scale=args.multi_scale,
265+
expanded_scales=args.expanded_scales,
266+
skip_random_resize=not args.do_random_resize_via_padding,
267+
patch_size=args.patch_size,
268+
num_windows=args.num_windows
269+
))
253270
else:
254-
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set, resolution, multi_scale=args.multi_scale))
271+
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(
272+
image_set,
273+
resolution,
274+
multi_scale=args.multi_scale,
275+
expanded_scales=args.expanded_scales,
276+
skip_random_resize=not args.do_random_resize_via_padding,
277+
patch_size=args.patch_size,
278+
num_windows=args.num_windows
279+
))
255280
return dataset

rfdetr/detr.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@
2323
except:
2424
pass
2525

26-
from rfdetr.config import RFDETRBaseConfig, RFDETRLargeConfig, TrainConfig, ModelConfig
26+
from rfdetr.config import (
27+
RFDETRBaseConfig,
28+
RFDETRLargeConfig,
29+
RFDETRNanoConfig,
30+
RFDETRSmallConfig,
31+
RFDETRMediumConfig,
32+
TrainConfig,
33+
ModelConfig
34+
)
2735
from rfdetr.main import Model, download_pretrain_weights
2836
from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink
2937
from rfdetr.util.coco_classes import COCO_CLASSES
@@ -32,6 +40,7 @@
3240
class RFDETR:
3341
means = [0.485, 0.456, 0.406]
3442
stds = [0.229, 0.224, 0.225]
43+
size = None
3544

3645
def __init__(self, **kwargs):
3746
self.model_config = self.get_model_config(**kwargs)
@@ -324,12 +333,48 @@ def predict(
324333
detections_list.append(detections)
325334

326335
return detections_list if len(detections_list) > 1 else detections_list[0]
336+
337+
def deploy_to_roboflow(self, workspace: str, project_ids: List[str], api_key: str = None, size: str = None, model_name: str = None):
338+
from roboflow import Roboflow
339+
import shutil
340+
if api_key is None:
341+
api_key = os.getenv("ROBOFLOW_API_KEY")
342+
if api_key is None:
343+
raise ValueError("Set api_key=<KEY> in deploy_to_roboflow or export ROBOFLOW_API_KEY=<KEY>")
344+
345+
346+
rf = Roboflow(api_key=api_key)
347+
workspace = rf.workspace(workspace)
348+
349+
if self.size is None and size is None:
350+
raise ValueError("Must set size for custom architectures")
351+
352+
size = self.size or size
353+
tmp_out_dir = ".roboflow_temp_upload"
354+
os.makedirs(tmp_out_dir, exist_ok=True)
355+
outpath = os.path.join(tmp_out_dir, "weights.pth")
356+
torch.save(
357+
{
358+
"model": self.model.model,
359+
"args": self.model.args
360+
}, outpath
361+
)
362+
363+
out = workspace.deploy_model(
364+
model_type=size,
365+
model_path=tmp_out_dir,
366+
project_ids=project_ids,
367+
model_name=model_name or size + "-uploaded"
368+
)
369+
return out
370+
327371

328372

329373
class RFDETRBase(RFDETR):
330374
"""
331375
Train an RF-DETR Base model (29M parameters).
332376
"""
377+
size = "rfdetr-base"
333378
def get_model_config(self, **kwargs):
334379
return RFDETRBaseConfig(**kwargs)
335380

@@ -338,10 +383,44 @@ def get_train_config(self, **kwargs):
338383

339384
class RFDETRLarge(RFDETR):
340385
"""
341-
Train an RF-DETR Base model.
386+
Train an RF-DETR Large model.
342387
"""
388+
size = "rfdetr-large"
343389
def get_model_config(self, **kwargs):
344390
return RFDETRLargeConfig(**kwargs)
345391

346392
def get_train_config(self, **kwargs):
347393
return TrainConfig(**kwargs)
394+
395+
class RFDETRNano(RFDETR):
396+
"""
397+
Train an RF-DETR Nano model.
398+
"""
399+
size = "rfdetr-nano"
400+
def get_model_config(self, **kwargs):
401+
return RFDETRNanoConfig(**kwargs)
402+
403+
def get_train_config(self, **kwargs):
404+
return TrainConfig(**kwargs)
405+
406+
class RFDETRSmall(RFDETR):
407+
"""
408+
Train an RF-DETR Small model.
409+
"""
410+
size = "rfdetr-small"
411+
def get_model_config(self, **kwargs):
412+
return RFDETRSmallConfig(**kwargs)
413+
414+
def get_train_config(self, **kwargs):
415+
return TrainConfig(**kwargs)
416+
417+
class RFDETRMedium(RFDETR):
418+
"""
419+
Train an RF-DETR Medium model.
420+
"""
421+
size = "rfdetr-medium"
422+
def get_model_config(self, **kwargs):
423+
return RFDETRMediumConfig(**kwargs)
424+
425+
def get_train_config(self, **kwargs):
426+
return TrainConfig(**kwargs)

0 commit comments

Comments
 (0)