-
-
Notifications
You must be signed in to change notification settings - Fork 298
implement IOU Loss #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
''' | ||
Based on: | ||
https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/86a370aa2cadea6ba7e5dffb2efc4bacc4c863ea/utils/box/box_utils.py#L47 | ||
|
||
Distance-IoU Loss: Faster and Better Learning for Bounding Box Regression | ||
https://arxiv.org/pdf/1911.08287.pdf | ||
Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression | ||
https://giou.stanford.edu/GIoU.pdf | ||
UnitBox: An Advanced Object Detection Network | ||
https://arxiv.org/pdf/1608.01471.pdf | ||
|
||
Important!!! (in case of c_iou_loss) | ||
targets -> bboxes1, preds -> bboxes2 | ||
''' | ||
|
||
import torch | ||
from torch import nn | ||
import numpy as np | ||
|
||
eps = 10e-16 | ||
|
||
|
||
def compute_iou(bboxes1, bboxes2): | ||
"bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]" | ||
assert bboxes1.size(0) == bboxes2.size(0) | ||
area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) | ||
area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) | ||
min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2]) | ||
max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0]) | ||
min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3]) | ||
max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1]) | ||
|
||
inter = torch.where(min_x2 - max_x1 > 0, min_x2 - max_x1, torch.tensor(0.)) * \ | ||
torch.where(min_y2 - max_y1 > 0, min_y2 - max_y1, torch.tensor(0.)) | ||
union = area1 + area2 - inter | ||
iou = inter / union | ||
iou = torch.clamp(iou, min=0, max=1.0) | ||
return iou | ||
|
||
|
||
def compute_g_iou(bboxes1, bboxes2): | ||
"box1 of shape [N, 4] and box2 of shape [N, 4]" | ||
#assert bboxes1.size(0) == bboxes2.size(0) | ||
area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) | ||
area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) | ||
min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2]) | ||
max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0]) | ||
min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3]) | ||
max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1]) | ||
inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0) | ||
union = area1 + area2 - inter | ||
C = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) * \ | ||
(torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])) | ||
g_iou = inter / union - (C - union) / C | ||
g_iou = torch.clamp(g_iou, min=0, max=1.0) | ||
return g_iou | ||
|
||
|
||
def compute_d_iou(bboxes1, bboxes2): | ||
"bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]" | ||
#assert bboxes1.size(0) == bboxes2.size(0) | ||
area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) | ||
area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) | ||
min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2]) | ||
max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0]) | ||
min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3]) | ||
max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1]) | ||
inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0) | ||
union = area1 + area2 - inter | ||
center_x1 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2 | ||
center_y1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2 | ||
center_x2 = (bboxes2[:, 2] + bboxes2[:, 0]) / 2 | ||
center_y2 = (bboxes2[:, 3] + bboxes2[:, 1]) / 2 | ||
|
||
# squared euclidian distance between the target and predicted bboxes | ||
d_2 = (center_x1 - center_x2) ** 2 + (center_y1 - center_y2) ** 2 | ||
# squared length of the diagonal of the minimum bbox that encloses both bboxes | ||
c_2 = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) ** 2 + ( | ||
torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])) ** 2 | ||
d_iou = inter / union - d_2 / c_2 | ||
d_iou = torch.clamp(d_iou, min=-1.0, max=1.0) | ||
|
||
return d_iou | ||
|
||
|
||
def compute_c_iou(bboxes1, bboxes2): | ||
"bboxes1 of shape [N, 4] and bboxes2 of shape [N, 4]" | ||
#assert bboxes1.size(0) == bboxes2.size(0) | ||
w1 = bboxes1[:, 2] - bboxes1[:, 0] | ||
h1 = bboxes1[:, 3] - bboxes1[:, 1] | ||
w2 = bboxes2[:, 2] - bboxes2[:, 0] | ||
h2 = bboxes2[:, 3] - bboxes2[:, 1] | ||
area1 = w1 * h1 | ||
area2 = w2 * h2 | ||
min_x2 = torch.min(bboxes1[:, 2], bboxes2[:, 2]) | ||
max_x1 = torch.max(bboxes1[:, 0], bboxes2[:, 0]) | ||
min_y2 = torch.min(bboxes1[:, 3], bboxes2[:, 3]) | ||
max_y1 = torch.max(bboxes1[:, 1], bboxes2[:, 1]) | ||
|
||
inter = torch.clamp(min_x2 - max_x1, min=0) * torch.clamp(min_y2 - max_y1, min=0) | ||
union = area1 + area2 - inter | ||
|
||
center_x1 = (bboxes1[:, 2] + bboxes1[:, 0]) / 2 | ||
center_y1 = (bboxes1[:, 3] + bboxes1[:, 1]) / 2 | ||
center_x2 = (bboxes2[:, 2] + bboxes2[:, 0]) / 2 | ||
center_y2 = (bboxes2[:, 3] + bboxes2[:, 1]) / 2 | ||
# squared euclidian distance between the target and predicted bboxes | ||
d_2 = (center_x1 - center_x2) ** 2 + (center_y1 - center_y2) ** 2 | ||
# squared length of the diagonal of the minimum bbox that encloses both bboxes | ||
c_2 = (torch.max(bboxes1[:, 2], bboxes2[:, 2]) - torch.min(bboxes1[:, 0], bboxes2[:, 0])) ** 2 + ( | ||
torch.max(bboxes1[:, 3], bboxes2[:, 3]) - torch.min(bboxes1[:, 1], bboxes2[:, 1])) ** 2 | ||
iou = inter / union | ||
v = 4 / np.pi ** 2 * (np.arctan(w1 / h1) - np.arctan(w2 / h2)) ** 2 | ||
with torch.no_grad(): | ||
S = 1 - iou | ||
alpha = v / (S + v + eps) | ||
c_iou = iou - (d_2 / c_2 + alpha * v) | ||
c_iou = torch.clamp(c_iou, min=-1.0, max=1.0) | ||
return c_iou | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,8 @@ | |
import torch.nn.functional as F | ||
|
||
from typing import Optional, List | ||
|
||
from .anchors import decode_box_outputs | ||
from .iou_loss import * | ||
|
||
def focal_loss(logits, targets, alpha: float, gamma: float, normalizer): | ||
"""Compute the focal loss between `logits` and the golden `target` values. | ||
|
@@ -119,15 +120,46 @@ def _box_loss(box_outputs, box_targets, num_positives, delta: float = 0.1): | |
return box_loss | ||
|
||
|
||
|
||
class IouLoss(nn.Module): | ||
|
||
def __init__(self, losstype='Giou', reduction='mean'): | ||
super(IouLoss, self).__init__() | ||
self.reduction = reduction | ||
self.loss = losstype | ||
|
||
def forward(self, target_bboxes, pred_bboxes): | ||
num = target_bboxes.shape[0] | ||
if self.loss == 'Iou': | ||
loss = torch.sum(1.0 - compute_iou(target_bboxes, pred_bboxes)) | ||
else: | ||
if self.loss == 'Giou': | ||
loss = torch.sum(1.0 - compute_g_iou(target_bboxes, pred_bboxes)) | ||
else: | ||
if self.loss == 'Diou': | ||
loss = torch.sum(1.0 - compute_d_iou(target_bboxes, pred_bboxes)) | ||
else: | ||
loss = torch.sum(1.0 - compute_c_iou(target_bboxes, pred_bboxes)) | ||
|
||
if self.reduction == 'mean': | ||
return loss / num | ||
else: | ||
return loss | ||
|
||
|
||
class DetectionLoss(nn.Module): | ||
def __init__(self, config): | ||
def __init__(self, config, anchors, use_iou_loss = False): | ||
super(DetectionLoss, self).__init__() | ||
self.config = config | ||
self.num_classes = config.num_classes | ||
self.alpha = config.alpha | ||
self.gamma = config.gamma | ||
self.delta = config.delta | ||
self.box_loss_weight = config.box_loss_weight | ||
self.use_iou_loss = use_iou_loss | ||
if self.use_iou_loss: | ||
self.anchors = anchors | ||
self.iou_loss = IouLoss() | ||
|
||
def forward( | ||
self, cls_outputs: List[torch.Tensor], box_outputs: List[torch.Tensor], | ||
|
@@ -161,6 +193,11 @@ def forward( | |
|
||
cls_losses = [] | ||
box_losses = [] | ||
if self.use_iou_loss: | ||
box_outputs_list = [] | ||
cls_targets_list = [] | ||
box_targets_list = [] | ||
|
||
for l in range(levels): | ||
cls_targets_at_level = cls_targets[l] | ||
box_targets_at_level = box_targets[l] | ||
|
@@ -182,12 +219,29 @@ def forward( | |
cls_loss = cls_loss.view(bs, height, width, -1, self.num_classes) | ||
cls_loss *= (cls_targets_at_level != -2).unsqueeze(-1).float() | ||
cls_losses.append(cls_loss.sum()) | ||
if not self.use_iou_loss: | ||
box_losses.append(_box_loss( | ||
box_outputs[l].permute(0, 2, 3, 1), | ||
box_targets_at_level, | ||
num_positives_sum, | ||
delta=self.delta)) | ||
|
||
else: | ||
box_outputs_list.append(box_outputs[l].permute(0, 2, 3, 1).reshape([bs, -1, 4])) | ||
cls_targets_list.append(cls_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 1])) | ||
box_targets_list.append(box_targets_at_level.permute(0, 2, 3, 1).reshape([bs, -1, 4])) | ||
|
||
|
||
if self.use_iou_loss: | ||
# apply bounding box regression to anchors | ||
for k in range(box_outputs_list.shape[0]): | ||
pred_boxes = decode_box_outputs(box_outputs_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It also a bug, ty, tx, th, tw = rel_codes.unbind(dim=1) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will check it, I think I was not working with the repo's last commit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for your reply, looking forward your new correct code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mind share your variable giou loss with effdet version? @hamadichihaoui |
||
target_boxes = decode_box_outputs(box_targets_list[k].T.float(), self.anchors.boxes.T, output_xyxy=True) | ||
# indices where an anchor is assigned to target box | ||
indices = box_targets_list[k] == 0.0 | ||
pred_boxes = torch.clamp(pred_boxes, 0) | ||
box_losses.append(self.iou_loss(target_boxes[indices.view(-1)], pred_boxes[indices.view(-1)])) | ||
|
||
box_losses.append(_box_loss( | ||
box_outputs[l].permute(0, 2, 3, 1), | ||
box_targets_at_level, | ||
num_positives_sum, | ||
delta=self.delta)) | ||
|
||
# Sum per level losses to total loss. | ||
cls_loss = torch.sum(torch.stack(cls_losses, dim=-1), dim=-1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems a bug, box_outputs_list is a list, has not shape attribute
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe I did a mistake,
for k in range(len(box_outputs_list)):