Skip to content

Commit f124941

Browse files
authored
Refactor model assets structure, for rfdetr 1.4.3+ (#6)
* Bump `rfdetr` dependency to `1.4.3` for compatibility * Add MD5 hashes for RF-DETR+ XLarge and 2XLarge model weights * Add inference tests for RF-DETR+ XLarge and 2XLarge models with parameterized resolutions
1 parent 57ddaa8 commit f124941

File tree

6 files changed

+102
-12
lines changed

6 files changed

+102
-12
lines changed

.github/workflows/ci-tests-cpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
run: uv pip install -e . --group tests
4949

5050
- name: Run the Test
51-
run: uv run pytest -m "not gpu" --cov=rfdetr_plus --cov-report=xml
51+
run: pytest -m "not gpu" --cov=rfdetr_plus --cov-report=xml
5252

5353
- name: Upload coverage to Codecov
5454
uses: codecov/codecov-action@v5

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ classifiers = [
3535
keywords = ["machine-learning", "deep-learning", "vision", "ML", "DL", "AI", "DETR", "RF-DETR", "Roboflow"]
3636

3737
dependencies = [
38-
"rfdetr>=1.4.1,<2",
38+
"rfdetr>=1.4.3,<2",
3939
]
4040

4141
[project.optional-dependencies]
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,9 @@
33
# Copyright (c) 2026 Roboflow, Inc. All Rights Reserved.
44
# Licensed under the Platform Model License 1.0 [see LICENSE for details]
55
# ------------------------------------------------------------------------
6-
"""Test basic package functionality."""
76

7+
"""RF-DETR+ model assets module."""
88

9-
def test_version() -> None:
10-
"""Test that the package version is accessible."""
11-
from rfdetr_plus import __version__
9+
from rfdetr_plus.assets.model_weights import ModelWeights
1210

13-
assert isinstance(__version__, str)
14-
assert len(__version__) > 0
11+
__all__ = ["ModelWeights"]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# ------------------------------------------------------------------------
2+
# RF-DETR+
3+
# Copyright (c) 2026 Roboflow, Inc. All Rights Reserved.
4+
# Licensed under the Platform Model License 1.0 [see LICENSE for details]
5+
# ------------------------------------------------------------------------
6+
7+
"""
8+
RF-DETR+ Model weights registry.
9+
10+
Provides ModelWeights enum for platform-licensed large-scale models,
11+
compatible with rf-detr's asset structure introduced in version 1.4.3.
12+
"""
13+
14+
from rfdetr.assets.model_weights import ModelWeightAsset, ModelWeightsBase
15+
16+
17+
class ModelWeights(ModelWeightsBase):
18+
"""
19+
Enumeration of RF-DETR+ platform-licensed model assets.
20+
21+
Inherits from rf-detr's ModelWeightsBase to ensure compatibility.
22+
23+
Each enum member's value is a ModelWeightAsset instance containing:
24+
- filename: The local filename for the model weights
25+
- url: The download URL
26+
- md5_hash: The expected MD5 hash for integrity validation
27+
28+
Example:
29+
>>> from rfdetr_plus.assets import ModelWeights
30+
>>> asset = ModelWeights.RF_DETR_XLARGE
31+
>>> asset.filename
32+
'rf-detr-xlarge.pth'
33+
>>> asset.url
34+
'https://storage.googleapis.com/rfdetr/platform-licensed/rf-detr-xlarge.pth'
35+
"""
36+
37+
# Platform-Licensed Detection Models (XLarge and 2XLarge)
38+
# These models are subject to the Platform Model License 1.0
39+
RF_DETR_XLARGE = ModelWeightAsset(
40+
"rf-detr-xlarge.pth",
41+
"https://storage.googleapis.com/rfdetr/platform-licensed/rf-detr-xlarge.pth",
42+
"6ddf834f2bc5bed3214a82f9b0aaeed7",
43+
)
44+
RF_DETR_XXLARGE = ModelWeightAsset(
45+
"rf-detr-xxlarge.pth",
46+
"https://storage.googleapis.com/rfdetr/platform-licensed/rf-detr-xxlarge.pth",
47+
"e3204689c1f0280427e4c33e6a2ac6cd",
48+
)
49+
50+
# All methods inherited from ModelWeightsBase:
51+
# - from_filename(filename) -> Optional[ModelWeightAsset]
52+
# - get_url(filename) -> Optional[str]
53+
# - get_md5(filename) -> Optional[str]
54+
# - list_models() -> list[str]

src/rfdetr_plus/models/downloads.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
# Licensed under the Platform Model License 1.0 [see LICENSE for details]
55
# ------------------------------------------------------------------------
66

7-
PLATFORM_MODELS = {
8-
"rf-detr-xlarge.pth": "https://storage.googleapis.com/rfdetr/platform-licensed/rf-detr-xlarge.pth",
9-
"rf-detr-xxlarge.pth": "https://storage.googleapis.com/rfdetr/platform-licensed/rf-detr-xxlarge.pth",
10-
}
7+
"""
8+
Legacy model downloads dictionary.
9+
10+
DEPRECATED: Use rfdetr_plus.assets.ModelWeights instead.
11+
This dictionary is maintained for backward compatibility only.
12+
"""
13+
14+
from rfdetr_plus.assets import ModelWeights
15+
16+
# Legacy dictionary for backward compatibility
17+
# New code should use rfdetr_plus.assets.ModelWeights
18+
PLATFORM_MODELS = {asset.filename: asset.url for asset in ModelWeights}

tests/test_inference.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# ------------------------------------------------------------------------
2+
# RF-DETR+
3+
# Copyright (c) 2026 Roboflow, Inc. All Rights Reserved.
4+
# Licensed under the Platform Model License 1.0 [see LICENSE for details]
5+
# ------------------------------------------------------------------------
6+
"""Test basic package functionality."""
7+
8+
import numpy as np
9+
import pytest
10+
11+
from rfdetr_plus import RFDETR2XLarge, RFDETRXLarge
12+
13+
14+
@pytest.mark.parametrize(
15+
("model_class", "resolution"),
16+
[
17+
(RFDETRXLarge, 700),
18+
(RFDETR2XLarge, 880),
19+
],
20+
)
21+
def test_model_inference(model_class, resolution) -> None:
22+
"""Test that we can instantiate RF-DETR+ models and run inference."""
23+
# Instantiate and run inference
24+
rf_detr = model_class()
25+
dummy_image = np.random.randint(0, 255, (resolution, resolution, 3), dtype=np.uint8)
26+
27+
# Run inference - this verifies the model can be instantiated and used
28+
predictions = rf_detr.predict(dummy_image, conf_threshold=0.1)
29+
30+
# Verify predictions were returned
31+
assert predictions is not None

0 commit comments

Comments
 (0)