Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
args: ["--print-width=120"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.4
rev: v0.15.5
hooks:
- id: ruff-check
args: [--fix]
Expand All @@ -51,7 +51,7 @@ repos:
args: ["--number"]

- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
rev: v2.4.2
hooks:
- id: codespell
# args: [--write-changes]
Expand Down
16 changes: 13 additions & 3 deletions src/rfdetr/models/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,19 @@ def forward(self, outputs, targets, group_detr=1):
C = C + self.cost_mask_ce * cost_mask_ce + self.cost_mask_dice * cost_mask_dice
C = C.view(bs, num_queries, -1).float().cpu() # convert to float because bfloat16 doesn't play nicely with CPU

# we assume any good match will not cause NaN or Inf, so we replace them with a large value
max_cost = C.max() if C.numel() > 0 else 0
C[C.isinf() | C.isnan()] = max_cost * 2
# We assume any good match will not cause NaN or Inf, so replace invalid
# entries with a finite value that is larger than every valid cost.
finite_mask = torch.isfinite(C)
if not finite_mask.all():
if finite_mask.any():
finite_costs = C[finite_mask]
max_cost = finite_costs.max()
# Add the largest absolute finite cost so the replacement stays
# strictly larger than every valid entry, even if all costs are negative.
replacement_cost = max_cost + finite_costs.abs().max() + 1
else:
replacement_cost = C.new_tensor(1.0)
C[~finite_mask] = replacement_cost

sizes = [len(v["boxes"]) for v in targets]
indices = []
Expand Down
45 changes: 45 additions & 0 deletions tests/models/test_matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import pytest
import torch

from rfdetr.models.matcher import HungarianMatcher


@pytest.mark.parametrize(
"invalid_value",
[
pytest.param(float("nan"), id="nan"),
pytest.param(float("inf"), id="inf"),
],
)
def test_matcher_replaces_non_finite_costs_before_assignment(invalid_value: float) -> None:
"""Matcher should sanitize non-finite costs so assignment still succeeds."""
matcher = HungarianMatcher()
outputs = {
"pred_logits": torch.tensor([[[0.0], [10.0]]], dtype=torch.float32),
"pred_boxes": torch.tensor(
[
[
[invalid_value, 0.5, 0.2, 0.2],
[0.5, 0.5, 0.2, 0.2],
]
],
dtype=torch.float32,
),
}
targets = [
{
"labels": torch.tensor([0], dtype=torch.int64),
"boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]], dtype=torch.float32),
}
]

matched_queries, matched_targets = matcher(outputs, targets)[0]

assert matched_queries.tolist() == [1]
assert matched_targets.tolist() == [0]
Loading