1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ dependencies = [
"pydantic",
"supervision",
"matplotlib",
"soft_moe",
Copilot AI Feb 19, 2026

The new dependency soft_moe is added without any version pinning, meaning builds will always pull the latest release from an external third-party package, which increases the risk of a malicious or compromised update being pulled into your supply chain. An attacker who gains control over the soft_moe package or its distribution channel could execute arbitrary code in your environment with the application's privileges. To reduce this risk, pin soft_moe to a specific, vetted version (or hash) and update it intentionally after review.
]

[project.optional-dependencies]
2 changes: 2 additions & 0 deletions rfdetr/config.py
@@ -30,6 +30,8 @@ class ModelConfig(BaseModel):
resolution: int = 560
group_detr: int = 13
gradient_checkpointing: bool = False
MoE: bool = False
MoE_params: List[int] = [32, 1]
Copilot AI Feb 19, 2026

Using a mutable default value (list) in Pydantic model fields can lead to unexpected behavior. Use Field(default_factory=lambda: [32, 1]) instead of MoE_params: List[int] = [32, 1] to ensure each instance gets a fresh list.
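A minimal sketch of the Field(default_factory=...) form suggested above; the config is trimmed to the two new fields for illustration:

from typing import List

from pydantic import BaseModel, Field


class ModelConfig(BaseModel):
    # Trimmed to the two new fields for illustration.
    MoE: bool = False
    # default_factory gives each ModelConfig instance its own fresh list
    MoE_params: List[int] = Field(default_factory=lambda: [32, 1])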
Comment on lines +33 to +34
Copilot AI Feb 19, 2026

The naming convention for boolean flags and their parameters is inconsistent with the existing codebase. Throughout the codebase, configuration parameters use snake_case (e.g., two_stage, bbox_reparam, lite_refpoint_refine, gradient_checkpointing). The parameter names MoE and MoE_params should be renamed to use_moe (or enable_moe) and moe_params respectively to maintain consistency.

class RFDETRBaseConfig(ModelConfig):
encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
63 changes: 53 additions & 10 deletions rfdetr/models/transformer.py
@@ -20,6 +20,7 @@
from typing import Optional

import torch
from soft_moe import SoftMoELayerWrapper
Copilot AI Feb 19, 2026

The soft_moe dependency is imported unconditionally at the module level, but it's only used when MoE is enabled. This means the library will be required even for users who don't use the MoE feature. Consider making this import conditional to avoid forcing all users to install this dependency. One approach is to import it inside the __init__ method of TransformerDecoderLayer when MoE=True, with a helpful error message if the import fails directing users to install it.

Suggested change
-from soft_moe import SoftMoELayerWrapper
+try:
+    from soft_moe import SoftMoELayerWrapper
+except ImportError:  # soft_moe is optional and only needed for MoE
+    class SoftMoELayerWrapper:  # type: ignore[no-redef]
+        """Placeholder wrapper used when soft_moe is not installed.
+
+        This is instantiated only if MoE functionality is enabled. If you hit this
+        error, install the optional `soft_moe` dependency:
+
+            pip install soft-moe
+        """
+
+        def __init__(self, *args, **kwargs) -> None:
+            raise ImportError(
+                "soft_moe is required for MoE support in TransformerDecoderLayer. "
+                "Please install the optional dependency with:\n\n"
+                "    pip install soft-moe\n"
+            )
import torch.nn.functional as F
from torch import nn, Tensor

Comment on lines +23 to 26
Copilot AI Feb 19, 2026

The import statement for soft_moe should be placed after the standard library imports and before the local imports, following PEP 8 import order conventions. Currently it sits between the torch and torch.nn.functional imports. Additionally, since this is a conditional dependency (only used when MoE=True), consider either: 1) making it an optional dependency in pyproject.toml and importing it conditionally within the __init__ method where it's used, or 2) documenting that soft_moe is required for MoE functionality.

Suggested change
-from soft_moe import SoftMoELayerWrapper
-import torch.nn.functional as F
-from torch import nn, Tensor
+import torch.nn.functional as F
+from torch import nn, Tensor
+from soft_moe import SoftMoELayerWrapper
@@ -39,6 +40,18 @@ def forward(self, x):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x

class FFNBlock(nn.Module):
    def __init__(self, d_model, dim_feedforward, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x):
        return self.net(x)
Comment on lines +43 to +54
Copilot AI Feb 19, 2026

The FFNBlock class is missing mandatory type hints for all parameters and the return type, as required by the coding guidelines. Additionally, it lacks a Google-style docstring explaining its purpose, parameters, and return value. This is required for all new classes according to CONTRIBUTING.md.

Copilot generated this review using guidance from repository custom instructions.
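For reference, a minimal sketch of how FFNBlock could satisfy both points while keeping the same behavior; the docstring wording is illustrative, not part of the PR:

from torch import nn, Tensor


class FFNBlock(nn.Module):
    """Feed-forward block used as the expert module inside the soft-MoE layer.

    Args:
        d_model: Input and output embedding dimension.
        dim_feedforward: Hidden dimension of the feed-forward network.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model: int, dim_feedforward: int, dropout: float) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the two-layer feed-forward network; output has the same shape as x."""
        return self.net(x)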

def gen_sineembed_for_position(pos_tensor, dim=128):
# n_query, bs, _ = pos_tensor.size()
@@ -136,7 +149,8 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
num_feature_levels=4, dec_n_points=4,
lite_refpoint_refine=False,
decoder_norm_type='LN',
bbox_reparam=False):
bbox_reparam=False,
MoE=False, MoE_params=[32,1]):
Copilot AI Feb 19, 2026

Using mutable default arguments (list) is a Python anti-pattern that can lead to unexpected behavior. The default value MoE_params=[32,1] should be replaced with None and then initialized inside the function. This applies to both the Transformer and TransformerDecoderLayer classes.

Copilot AI Feb 19, 2026

The parameter naming MoE and MoE_params is inconsistent with the existing codebase conventions. Throughout the codebase, configuration parameters use snake_case (e.g., two_stage, bbox_reparam, skip_self_attn). These should be renamed to use_moe (or enable_moe) and moe_params to maintain consistency with codebase naming conventions.

Copilot AI Feb 19, 2026

Inconsistent spacing in list literal. The default value [32,1] should have spaces after commas as per PEP 8 style guide: [32, 1]. This applies to all occurrences in the code.

Suggested change
-                 MoE=False, MoE_params=[32,1]):
+                 MoE=False, MoE_params=[32, 1]):
super().__init__()
self.encoder = None

Expand All @@ -145,7 +159,9 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
group_detr=group_detr,
num_feature_levels=num_feature_levels,
dec_n_points=dec_n_points,
skip_self_attn=False,)
skip_self_attn=False,
MoE=MoE,
MoE_params=MoE_params)
assert decoder_norm_type in ['LN', 'Identity']
norm = {
"LN": lambda channels: nn.LayerNorm(channels),
@@ -441,7 +457,7 @@ class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False, group_detr=1,
num_feature_levels=4, dec_n_points=4,
skip_self_attn=False):
skip_self_attn=False, MoE=False, MoE_params=[32,1]):
Copilot AI Feb 19, 2026

Inconsistent spacing in list literal. The default value [32,1] should have spaces after commas as per PEP 8 style guide: [32, 1].

Suggested change
-                 skip_self_attn=False, MoE=False, MoE_params=[32,1]):
+                 skip_self_attn=False, MoE=False, MoE_params=[32, 1]):

Copilot AI Feb 19, 2026

The parameter naming MoE and MoE_params is inconsistent with the existing codebase conventions. Throughout the codebase, configuration parameters use snake_case (e.g., two_stage, bbox_reparam, skip_self_attn). These should be renamed to use_moe (or enable_moe) and moe_params to maintain consistency with codebase naming conventions.
super().__init__()
Comment on lines +460 to 461
Copilot AI Feb 19, 2026

Using mutable default arguments (list) is a Python anti-pattern that can lead to unexpected behavior. The default value MoE_params=[32,1] should be replaced with None and then initialized inside the function.

Suggested change
-                 skip_self_attn=False, MoE=False, MoE_params=[32,1]):
-        super().__init__()
+                 skip_self_attn=False, MoE=False, MoE_params=None):
+        super().__init__()
+        if MoE_params is None:
+            MoE_params = [32, 1]
# Decoder Self-Attention
self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
@@ -453,19 +469,41 @@ def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)

self.nhead = ca_nhead

# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)

# Implementation of Feedforward or the MoE Layer (done by @LeosCtrt)
Copilot AI Feb 19, 2026

The comment "done by @LeosCtrt" contains contributor attribution that should be removed from the code. Git history and PR metadata already track authorship. Keep code comments focused on explaining what the code does and why, not who wrote it.

Suggested change
-        # Implementation of Feedforward or the MoE Layer (done by @LeosCtrt)
+        # Implementation of the feedforward (FFN) or the MoE layer
self.MoE = MoE
if self.MoE == True:
Copilot AI Feb 19, 2026

Comparing boolean values with == True is not Pythonic. Use if self.MoE: instead of if self.MoE == True:. This applies to both occurrences in the code.

Suggested change
-        if self.MoE == True:
+        if self.MoE:
print("\n" + "="*80)
print("Loading Mixture of Expert (MoE) Architecture")
print("="*80)
print(f"Experts Count : {MoE_params[0]}")
print(f"Slots per Expert : {MoE_params[1]}")
print("-"*80)
print("Warning: This custom architecture prevents loading full pretrained weights.")
print("Note : It may be slightly slower but could improve accuracy.")
print("="*80 + "\n")
Comment on lines +476 to +484
Copilot AI Feb 19, 2026

Using print() statements for architecture notifications is not consistent with the project's logging practices. According to the coding guidelines, use from rfdetr.util.logger import get_logger and log messages with logger.info() for user-facing messages or logger.debug() for detailed information. This ensures consistent logging behavior and respects the LOG_LEVEL environment variable.

Copilot generated this review using guidance from repository custom instructions.
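As an illustration only, a sketch of what that replacement could look like; it assumes the get_logger helper exists at rfdetr.util.logger as the guidelines cited above describe, and the log_moe_banner function name is hypothetical:

from rfdetr.util.logger import get_logger  # assumed helper, per the guidelines cited above

logger = get_logger(__name__)


def log_moe_banner(num_experts: int, slots_per_expert: int) -> None:
    """Log the MoE configuration instead of printing a banner to stdout."""
    logger.info("Loading Mixture of Experts (MoE) architecture")
    logger.info("Experts: %d, slots per expert: %d", num_experts, slots_per_expert)
    logger.info(
        "MoE replaces the decoder FFN, so full pretrained weights cannot be loaded; "
        "it may be slightly slower but could improve accuracy."
    )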

self.moe_layer = SoftMoELayerWrapper(
dim=d_model,
num_experts=MoE_params[0],
slots_per_expert=MoE_params[1],
layer=FFNBlock,
d_model=d_model,
dim_feedforward=dim_feedforward,
dropout=dropout
)
Comment on lines +475 to +494
Copilot AI Feb 19, 2026

There's no validation for the MoE_params parameter. The code assumes it's a list with exactly 2 elements and directly accesses MoE_params[0] and MoE_params[1]. If the user provides an invalid value (e.g., empty list, single element, non-list), this will raise an IndexError. Add validation to ensure MoE_params contains exactly 2 positive integers before using them.
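One possible shape for that check, as a self-contained sketch; the helper name and error wording are illustrative, not part of the PR:

def validate_moe_params(moe_params) -> tuple:
    """Return (num_experts, slots_per_expert) or raise a descriptive error."""
    if (
        not isinstance(moe_params, (list, tuple))
        or len(moe_params) != 2
        or not all(isinstance(v, int) and v > 0 for v in moe_params)
    ):
        raise ValueError(
            "MoE_params must be two positive integers [num_experts, slots_per_expert], "
            f"got {moe_params!r}"
        )
    return moe_params[0], moe_params[1]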
Comment on lines +486 to +494
Copilot AI Feb 19, 2026

The MoE implementation will significantly increase memory usage and model parameters. With the default num_experts=32, the FFN layer is replicated 32 times per decoder layer. For a 3-layer decoder, this means 96 expert FFN blocks instead of 3 regular FFN blocks. This should be documented in the docstring or config to help users understand the memory/compute tradeoffs. Consider adding a warning or documentation about the expected memory increase based on the number of experts.
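To give a rough sense of the scale described above, a back-of-the-envelope sketch; the d_model and dim_feedforward values are placeholders rather than the model's actual configuration, and the soft-MoE routing parameters are ignored:

def ffn_param_count(d_model: int, dim_feedforward: int) -> int:
    """Parameters of a two-layer FFN: two weight matrices plus their biases."""
    return 2 * d_model * dim_feedforward + dim_feedforward + d_model


# Placeholder sizes purely for illustration.
d_model, dim_feedforward = 256, 2048
num_decoder_layers, num_experts = 3, 32

dense = num_decoder_layers * ffn_param_count(d_model, dim_feedforward)
moe = num_decoder_layers * num_experts * ffn_param_count(d_model, dim_feedforward)
print(f"dense FFN params: {dense:,}  vs  MoE FFN params: {moe:,}  (~{moe // dense}x)")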
else:
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.activation = _get_activation_fn(activation)
Comment on lines +495 to +499
Copilot AI Feb 19, 2026

The activation variable is only initialized when MoE=False, but if MoE=True, the activation parameter is passed but never used since it's not assigned. This could cause an AttributeError if code later tries to access self.activation when MoE is enabled. Consider either removing the activation parameter entirely when MoE is True or initializing self.activation regardless of the MoE flag to maintain consistency.

self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)

self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)

self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
self.group_detr = group_detr

@@ -521,7 +559,10 @@ def forward_post(self, tgt, memory,

tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
if self.MoE == True:
Copilot AI Feb 19, 2026

Comparing boolean values with == True is not Pythonic. Use if self.MoE: instead of if self.MoE == True:.

Suggested change
-        if self.MoE == True:
+        if self.MoE:
tgt2 = self.moe_layer(tgt)
else:
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
@@ -571,6 +612,8 @@ def build_transformer(args):
lite_refpoint_refine=args.lite_refpoint_refine,
decoder_norm_type=args.decoder_norm,
bbox_reparam=args.bbox_reparam,
MoE=args.MoE,
MoE_params=args.MoE_params
)

