Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions rfdetr/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,95 @@ def pad(image, target, padding):
target['masks'], (0, padding[0], 0, padding[1]))
return padded_image, target

def rotate(image, target, angle):
    """Rotate an image and its annotations by ``angle`` degrees.

    Positive angles rotate counter-clockwise on screen, matching
    torchvision's ``F.rotate``. The canvas is expanded so no content is
    cropped, and the target is updated to the new canvas.

    Args:
        image: PIL image.
        target: annotation dict; optional keys handled here are
            "boxes" ((N, 4) xyxy in absolute pixels — TODO confirm against
            the callers in this pipeline), "masks" ((N, H, W) bool/uint8),
            and "area" ((N,)). A shallow copy is returned; "size" is set
            to the new (h, w).
        angle: rotation angle in degrees.

    Returns:
        (rotated_image, updated_target)
    """
    rotated_image = F.rotate(image, angle, expand=True)

    # Original and expanded canvas sizes and centers.
    w, h = image.size
    new_w, new_h = rotated_image.size
    cx_old, cy_old = w / 2, h / 2
    cx_new, cy_new = new_w / 2, new_h / 2

    target = target.copy()
    target["size"] = torch.tensor([new_h, new_w])

    # ============================================================
    # 1. Rotate masks (round-trip through PIL; default nearest
    #    resampling keeps them binary)
    # ============================================================
    if "masks" in target:
        rotated_masks = []
        for m in target["masks"]:
            pil_m = F.to_pil_image(m.byte() * 255)
            pil_m = F.rotate(pil_m, angle, expand=True)
            rotated_masks.append(torch.from_numpy(np.array(pil_m)).bool())
        target["masks"] = torch.stack(rotated_masks, dim=0)

    # ============================================================
    # 2. Rotate bounding boxes
    # ============================================================
    if "boxes" in target:
        boxes = target["boxes"]  # (N, 4), xyxy

        # Expand each box into its 4 corner points.
        x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        corners = torch.stack([
            torch.stack([x1, y1], dim=1),
            torch.stack([x2, y1], dim=1),
            torch.stack([x2, y2], dim=1),
            torch.stack([x1, y2], dim=1),
        ], dim=1)  # (N, 4, 2)

        # Rotation matrix in *image* coordinates. The y axis points down,
        # so a counter-clockwise rotation on screen is a clockwise rotation
        # in standard math axes: the sign of sin is flipped relative to the
        # textbook CCW matrix. (With the textbook [[cos,-sin],[sin,cos]]
        # the boxes would rotate opposite to the image — e.g. at 90° a
        # top-right box would land bottom-right instead of top-left.)
        theta = torch.tensor(angle * np.pi / 180.0)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        R = torch.stack([
            torch.stack([cos_t, sin_t]),
            torch.stack([-sin_t, cos_t]),
        ])

        # Rotate corners about the old center, then place them relative
        # to the new (expanded) canvas center.
        corners_shifted = corners - torch.tensor([cx_old, cy_old])
        rotated = corners_shifted @ R.T
        rotated = rotated + torch.tensor([cx_new, cy_new])

        # Axis-aligned box enclosing the rotated corners.
        x_coords = rotated[..., 0]
        y_coords = rotated[..., 1]
        target["boxes"] = torch.stack([
            x_coords.min(dim=1).values,
            y_coords.min(dim=1).values,
            x_coords.max(dim=1).values,
            y_coords.max(dim=1).values,
        ], dim=1)

    # ============================================================
    # 3. Recompute areas from the new axis-aligned boxes
    # ============================================================
    if "area" in target and "boxes" in target:
        new_boxes = target["boxes"]
        wh = (new_boxes[:, 2:] - new_boxes[:, :2]).clamp(min=0)
        target["area"] = wh[:, 0] * wh[:, 1]

    return rotated_image, target

class RandomRotate(object):
    """Rotate (image, target) by a uniformly random angle in
    [-max_angle, max_angle] degrees."""

    def __init__(self, max_angle):
        self.max_angle = max_angle

    def __call__(self, image, target):
        # Draw the rotation angle with torch's RNG so it participates in
        # torch seeding, then hand off to the functional transform.
        sampled = torch.empty(1).uniform_(-self.max_angle, self.max_angle)
        return rotate(image, target, float(sampled))


class RandomCrop(object):
def __init__(self, size):
Expand Down
20 changes: 16 additions & 4 deletions rfdetr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ def lr_lambda(current_step: int):

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
model_without_ddp.load_state_dict(state_dict, strict=True)
if args.use_ema:
if 'ema_model' in checkpoint:
self.ema_m.module.load_state_dict(clean_state_dict(checkpoint['ema_model']))
Expand Down Expand Up @@ -498,10 +499,17 @@ def lr_lambda(current_step: int):
self.model.eval()

if args.run_test:
best_state_dict = torch.load(output_dir / 'checkpoint_best_total.pth', map_location='cpu', weights_only=False)['model']
model.load_state_dict(best_state_dict)
time.sleep(5)
checkpoint = torch.load(output_dir / 'checkpoint_best_total.pth', map_location='cpu', weights_only=False)
best_state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

# Clean the state dict to remove 'module.' prefix if present
best_state_dict = clean_state_dict(best_state_dict)

# Load into the unwrapped model to match non-DDP-saved checkpoint keys
model.module.load_state_dict(best_state_dict)

model.eval()

test_stats, _ = evaluate(
model, criterion, postprocess, data_loader_test, base_ds_test, device, args=args
)
Expand Down Expand Up @@ -770,6 +778,8 @@ def get_args_parser():
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--focal_alpha', default=0.25, type=float)
parser.add_argument('--hausdorff_loss_coef', default=1.0, type=float,
help="Coefficient for the Hausdorff distance loss")

# Loss
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
Expand Down Expand Up @@ -930,6 +940,7 @@ def populate_args(
bbox_loss_coef=5,
giou_loss_coef=2,
focal_alpha=0.25,
hausdorff_loss_coef=1.0,
aux_loss=True,
sum_group_losses=False,
use_varifocal_loss=False,
Expand Down Expand Up @@ -1041,6 +1052,7 @@ def populate_args(
bbox_loss_coef=bbox_loss_coef,
giou_loss_coef=giou_loss_coef,
focal_alpha=focal_alpha,
hausdorff_loss_coef=hausdorff_loss_coef,
aux_loss=aux_loss,
sum_group_losses=sum_group_losses,
use_varifocal_loss=use_varifocal_loss,
Expand Down
53 changes: 49 additions & 4 deletions rfdetr/models/lwdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import copy
import math
from typing import Callable
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from scipy.spatial.distance import directed_hausdorff

from rfdetr.util import box_ops
from rfdetr.util.misc import (NestedTensor, nested_tensor_from_tensor_list,
Expand Down Expand Up @@ -129,7 +131,7 @@ def export(self):
m.export()

def forward(self, samples: NestedTensor, targets=None):
""" The forward expects a NestedTensor, which consists of:
""" The forward expects a NestedTensor, which consists of:
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

Expand Down Expand Up @@ -503,7 +505,47 @@ def loss_masks(self, outputs, targets, indices, num_boxes):
del target_masks
return losses


    def loss_hausdorff(self, outputs, targets, indices, num_boxes):
        """
        Compute the (symmetric) Hausdorff distance between predicted and
        target masks, averaged over ``num_boxes``.

        NOTE(review): the distance is computed from ``torch.nonzero``
        coordinates of *thresholded* masks. Both the `> 0.5` comparison and
        `nonzero` are non-differentiable, so no gradient flows back to
        ``pred_masks`` — as written this term cannot train the mask head.
        Confirm whether a differentiable surrogate (e.g. distance-transform
        weighted loss) was intended.

        NOTE(review): distances are in pixel-index units of each tensor. If
        ``pred_masks`` is at a lower resolution than the GT masks (no
        interpolation is done here, unlike a typical ``loss_masks``), the
        two coordinate sets are not in the same frame — verify resolutions
        match at the call site.

        Args:
            outputs: dict containing 'pred_masks' (mask logits).
            targets: list of per-image target dicts with 'masks'.
            indices: matcher output; list of (src_idx, tgt_idx) per image.
            num_boxes: normalizer shared with the other losses.

        Returns:
            dict with a single entry 'loss_hausdorff'.
        """
        assert 'pred_masks' in outputs
        # Select the matched predictions and concatenate matched GT masks
        # in the same order.
        idx = self._get_src_permutation_idx(indices)
        src_masks = outputs['pred_masks'][idx]
        target_masks = torch.cat([t['masks'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        # Handle empty case (no matched masks in the batch).
        if src_masks.numel() == 0:
            return {'loss_hausdorff': torch.tensor(0.0, device=src_masks.device)}

        # Binarize masks (threshold at 0.5) - stay on GPU.
        src_masks_binary = (src_masks.sigmoid() > 0.5)
        target_masks_binary = (target_masks > 0.5)

        # Accumulate loss directly on GPU to avoid list append overhead.
        hausdorff_loss = torch.tensor(0.0, device=src_masks.device)

        for src, tgt in zip(src_masks_binary, target_masks_binary):
            # Extract coordinates of positive pixels on GPU.
            src_coords = torch.nonzero(src, as_tuple=False).float()
            tgt_coords = torch.nonzero(tgt, as_tuple=False).float()

            # Skip if either mask is empty (distance undefined); note such
            # pairs contribute 0 but still count in the num_boxes normalizer.
            if src_coords.shape[0] == 0 or tgt_coords.shape[0] == 0:
                continue

            # Symmetric Hausdorff distance: max of the two directed
            # distances, each "max over one set of (min distance to the
            # other set)".
            # NOTE(review): cdist builds a (|src| x |tgt|) matrix — memory
            # is quadratic in foreground pixel count for large masks.
            dist_matrix = torch.cdist(src_coords, tgt_coords, p=2)
            forward_hd = dist_matrix.min(dim=1)[0].max()
            backward_hd = dist_matrix.min(dim=0)[0].max()
            hausdorff_loss += torch.max(forward_hd, backward_hd)

        # Normalize by num_boxes instead of valid_count to match other losses.
        hausdorff_loss = hausdorff_loss / num_boxes
        return {'loss_hausdorff': hausdorff_loss}

def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
Expand All @@ -522,6 +564,7 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
'cardinality': self.loss_cardinality,
'boxes': self.loss_boxes,
'masks': self.loss_masks,
'hausdorff': self.loss_hausdorff,
}
assert loss in loss_map, f'do you really want to compute {loss} loss?'
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
Expand Down Expand Up @@ -581,7 +624,7 @@ def forward(self, outputs, targets):
return losses


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.75, gamma: float = 2):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
Expand Down Expand Up @@ -836,6 +879,7 @@ def build_criterion_and_postprocessors(args):
if args.segmentation_head:
weight_dict['loss_mask_ce'] = args.mask_ce_loss_coef
weight_dict['loss_mask_dice'] = args.mask_dice_loss_coef
weight_dict['loss_hausdorff'] = args.hausdorff_loss_coef
# TODO this is a hack
if args.aux_loss:
aux_weight_dict = {}
Expand All @@ -848,6 +892,7 @@ def build_criterion_and_postprocessors(args):
losses = ['labels', 'boxes', 'cardinality']
if args.segmentation_head:
losses.append('masks')
losses.append('hausdorff')

try:
sum_group_losses = args.sum_group_losses
Expand All @@ -871,4 +916,4 @@ def build_criterion_and_postprocessors(args):
criterion.to(device)
postprocess = PostProcess(num_select=args.num_select)

return criterion, postprocess
return criterion, postprocess
Loading