
[On-device Training] Yolo custom loss #19464




Discussed in #19390

Originally posted by Marouan-st February 2, 2024

I would like to implement a custom loss so that I can train a YOLOv4-tiny object-detection model on-device.

To compute the loss, some post-processing must be performed on the model output, such as computing the IoU between bounding boxes, and several losses must be summed (class loss + confidence loss + IoU loss; the class and confidence losses are cross-entropy losses): see

I don't see how to implement all these computations in the custom loss, and in particular how to feed the post-processed tensors to the individual losses, since the ONNX loss functions take string arguments (input names) as their inputs.
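For reference, the total loss that the loss code below is meant to compute can be written as follows. The notation is my own, not taken from ONNX Runtime or the TensorFlow source: for anchor $j$ of one image, $r_j$ is `respond_bbox`, $b_j$ is `respond_bgd`, $p_j$ is the predicted confidence, $c_j$ and $z_{jk}$ are the raw confidence and class logits, $y_{jk}$ is the class label, $w_j, h_j$ are the ground-truth width and height, $\hat{B}_j$ is the predicted box, $B_m$ are the ground-truth boxes, $S$ is the input size and $\tau$ is `IOU_LOSS_THRESH`. Each sum runs over the cells and anchors of one image; each term is then averaged over the batch.

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\mathrm{giou}} + \mathcal{L}_{\mathrm{conf}} + \mathcal{L}_{\mathrm{prob}} \\
\mathcal{L}_{\mathrm{giou}} &= \sum_j r_j \left(2 - \frac{w_j h_j}{S^2}\right)\left(1 - \mathrm{GIoU}(\hat{B}_j, B_j)\right) \\
\mathcal{L}_{\mathrm{conf}} &= \sum_j (r_j - p_j)^2 \,(r_j + b_j)\,\mathrm{BCE}(c_j, r_j),
\qquad b_j = (1 - r_j)\,\mathbb{1}\!\left[\max_m \mathrm{IoU}(\hat{B}_j, B_m) < \tau\right] \\
\mathcal{L}_{\mathrm{prob}} &= \sum_j r_j \sum_k \mathrm{BCE}(z_{jk}, y_{jk}) \\
\mathrm{BCE}(x, t) &= -\,t \log \sigma(x) - (1 - t)\log\bigl(1 - \sigma(x)\bigr)
\end{aligned}
```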

I'm using a YOLOv4-tiny model compiled with Darknet and converted to ONNX through a TensorFlow implementation of the model.

A PyTorch implementation of this loss function (for the model I'm using) would look like this (inspired by this YOLOv4 loss TensorFlow implementation):

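import torch
import torch.nn.functional as F

def compute_loss(pred, conv, label, bboxes, STRIDES=(16, 32), NUM_CLASS=1, IOU_LOSS_THRESH=0.5, i=0):
    # conv:   raw head output, channels-last, reshaped below to (batch, S, S, 3, 5 + NUM_CLASS)
    # pred:   decoded output (batch, S, S, 3, 5 + NUM_CLASS), boxes as (x, y, w, h) in pixels
    # label:  targets with the same layout as pred
    # bboxes: ground-truth boxes per image (batch, max_boxes, 4) as (x, y, w, h)
    conv_shape  = conv.size()
    batch_size  = conv_shape[0]
    output_size = conv_shape[1]
    input_size  = float(STRIDES[i] * output_size)  # network input resolution in pixels
    conv = torch.reshape(conv, (batch_size, output_size, output_size, 3, 5 + NUM_CLASS))

    conv_raw_conf = conv[:, :, :, :, 4:5]
    conv_raw_prob = conv[:, :, :, :, 5:]

    pred_xywh     = pred[:, :, :, :, 0:4]
    pred_conf     = pred[:, :, :, :, 4:5]

    label_xywh    = label[:, :, :, :, 0:4]
    respond_bbox  = label[:, :, :, :, 4:5]  # 1.0 where an anchor is responsible for an object
    label_prob    = label[:, :, :, :, 5:]

    # bbox_giou returns (batch, S, S, 3); add a trailing axis to broadcast with respond_bbox
    giou = torch.unsqueeze(bbox_giou(pred_xywh, label_xywh), -1)

    # Small ground-truth boxes get a larger weight (up to 2.0)
    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    giou_loss = respond_bbox * bbox_loss_scale * (1 - giou)

    # IoU of every predicted box against every ground-truth box of the same image: (batch, S, S, 3, max_boxes)
    iou = bbox_iou(pred_xywh[:, :, :, :, None, :], bboxes[:, None, None, None, :, :])
    max_iou = torch.amax(iou, dim=-1, keepdim=True)

    # Background anchors that do not overlap any ground-truth box above the threshold
    respond_bgd = (1.0 - respond_bbox) * (max_iou < IOU_LOSS_THRESH).to(torch.float32)

    conf_focal = torch.pow(respond_bbox - pred_conf, 2)

    # Element-wise sigmoid cross-entropy (no reduction), like tf.nn.sigmoid_cross_entropy_with_logits
    conf_ce = F.binary_cross_entropy_with_logits(conv_raw_conf, respond_bbox, reduction="none")
    conf_loss = conf_focal * (respond_bbox * conf_ce + respond_bgd * conf_ce)

    prob_loss = respond_bbox * F.binary_cross_entropy_with_logits(conv_raw_prob, label_prob, reduction="none")

    # Sum over grid cells and anchors, then average over the batch
    giou_loss = torch.mean(torch.sum(giou_loss, dim=(1, 2, 3, 4)))
    conf_loss = torch.mean(torch.sum(conf_loss, dim=(1, 2, 3, 4)))
    prob_loss = torch.mean(torch.sum(prob_loss, dim=(1, 2, 3, 4)))

    return giou_loss + conf_loss + prob_loss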
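To sanity-check the per-anchor terms without tensors, here is a minimal stdlib sketch of the GIoU and confidence terms for a single anchor. It assumes that `pred_conf` is `sigmoid(conv_raw_conf)`, as in the usual YOLOv4 decode step; the helper names `bce_with_logits`, `giou_term` and `conf_term` are hypothetical and not part of any library. `bce_with_logits` is meant as the scalar counterpart of `torch.nn.functional.binary_cross_entropy_with_logits(..., reduction="none")`.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable sigmoid cross-entropy for one element."""
    # Equals -t*log(sigmoid(x)) - (1 - t)*log(1 - sigmoid(x))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def giou_term(respond: float, w: float, h: float, input_size: float, giou: float) -> float:
    """GIoU loss of one anchor; small ground-truth boxes get a weight closer to 2."""
    scale = 2.0 - w * h / input_size ** 2
    return respond * scale * (1.0 - giou)


def conf_term(raw_conf: float, respond: float, max_iou: float, iou_thresh: float = 0.5) -> float:
    """Confidence loss of one anchor (hypothetical scalar version of conf_loss)."""
    pred_conf = 1.0 / (1.0 + math.exp(-raw_conf))  # assumption: decoded conf = sigmoid(raw)
    # Background anchors only count if no ground truth overlaps them above the threshold
    respond_bgd = (1.0 - respond) * (1.0 if max_iou < iou_thresh else 0.0)
    focal = (respond - pred_conf) ** 2
    ce = bce_with_logits(raw_conf, respond)
    return focal * (respond * ce + respond_bgd * ce)


# Anchor responsible for a 416x416 box in a 416 input, predicted with GIoU 0.5
print(giou_term(1.0, 416.0, 416.0, 416.0, 0.5))  # 0.5
# Background anchor whose best IoU (0.7) exceeds the threshold is ignored
print(conf_term(0.0, 0.0, 0.7))  # 0.0
```

Note that the positive and background branches share the same cross-entropy value; only the masks `respond` and `respond_bgd` decide which anchors contribute.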

def bbox_iou(bboxes1, bboxes2):
    """
    IoU of boxes given as (x_center, y_center, w, h).
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    # (x, y, w, h) -> (x_min, y_min, x_max, y_max)
    bboxes1_coor = torch.cat(
        [bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
         bboxes1[..., :2] + bboxes1[..., 2:] * 0.5],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
         bboxes2[..., :2] + bboxes2[..., 2:] * 0.5],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])

    # Negative extents mean no overlap; clamp them to zero
    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]

    union_area = bboxes1_area + bboxes2_area - inter_area

    iou = torch.div(inter_area, union_area)

    return iou
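To check `bbox_iou` on a concrete case, the following stdlib-only sketch computes the same center-format IoU for a single pair of boxes. `to_corners` and `iou_xywh` are hypothetical helper names, and boxes are assumed to be `(x_center, y_center, w, h)` tuples, as in `bbox_iou`.

```python
def to_corners(box):
    """(x_center, y_center, w, h) -> (x_min, y_min, x_max, y_max)."""
    x, y, w, h = box
    return x - w / 2, y - h / 2, x + w / 2, y + h / 2


def iou_xywh(box1, box2):
    """IoU of two center-format boxes, mirroring bbox_iou for one pair."""
    l1, t1, r1, b1 = to_corners(box1)
    l2, t2, r2, b2 = to_corners(box2)
    # Intersection extents, clamped at zero when the boxes do not overlap
    inter_w = max(min(r1, r2) - max(l1, l2), 0.0)
    inter_h = max(min(b1, b2) - max(t1, t2), 0.0)
    inter = inter_w * inter_h
    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    return inter / union


# Two 2x2 boxes offset by (1, 1): overlap 1, union 4 + 4 - 1 = 7
print(iou_xywh((1, 1, 2, 2), (2, 2, 2, 2)))  # 0.14285714285714285
```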

def bbox_giou(bboxes1, bboxes2):
    """
    Generalized IoU of boxes given as (x_center, y_center, w, h).
    @param bboxes1: (a, b, ..., 4)
    @param bboxes2: (A, B, ..., 4)
        x:X is 1:n or n:n or n:1
    @return (max(a,A), max(b,B), ...)
    ex) (4,):(3,4) -> (3,)
        (2,1,4):(2,3,4) -> (2,3)
    """
    bboxes1_area = bboxes1[..., 2] * bboxes1[..., 3]
    bboxes2_area = bboxes2[..., 2] * bboxes2[..., 3]

    # (x, y, w, h) -> (x_min, y_min, x_max, y_max)
    bboxes1_coor = torch.cat(
        [bboxes1[..., :2] - bboxes1[..., 2:] * 0.5,
         bboxes1[..., :2] + bboxes1[..., 2:] * 0.5],
        dim=-1,
    )
    bboxes2_coor = torch.cat(
        [bboxes2[..., :2] - bboxes2[..., 2:] * 0.5,
         bboxes2[..., :2] + bboxes2[..., 2:] * 0.5],
        dim=-1,
    )

    left_up = torch.maximum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    right_down = torch.minimum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])
    inter_section = torch.clamp(right_down - left_up, min=0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = bboxes1_area + bboxes2_area - inter_area
    iou = torch.div(inter_area, union_area)

    # Smallest axis-aligned box enclosing both boxes
    enclose_left_up = torch.minimum(bboxes1_coor[..., :2], bboxes2_coor[..., :2])
    enclose_right_down = torch.maximum(bboxes1_coor[..., 2:], bboxes2_coor[..., 2:])
    enclose_section = enclose_right_down - enclose_left_up
    enclose_area = enclose_section[..., 0] * enclose_section[..., 1]

    giou = iou - torch.div(enclose_area - union_area, enclose_area)
    return giou
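A stdlib-only sketch of the same generalized IoU for one pair shows why GIoU is used in the box-regression term: unlike IoU, which is stuck at 0 for any non-overlapping pair, GIoU keeps decreasing toward -1 as the boxes move apart, so the loss still distinguishes "close" from "far" misses. The helper names are hypothetical, and boxes are assumed to be `(x_center, y_center, w, h)` tuples.

```python
def to_corners(box):
    """(x_center, y_center, w, h) -> (x_min, y_min, x_max, y_max)."""
    x, y, w, h = box
    return x - w / 2, y - h / 2, x + w / 2, y + h / 2


def giou_xywh(box1, box2):
    """Generalized IoU of two center-format boxes, mirroring bbox_giou for one pair."""
    l1, t1, r1, b1 = to_corners(box1)
    l2, t2, r2, b2 = to_corners(box2)
    inter = max(min(r1, r2) - max(l1, l2), 0.0) * max(min(b1, b2) - max(t1, t2), 0.0)
    union = box1[2] * box1[3] + box2[2] * box2[3] - inter
    iou = inter / union
    # Area of the smallest axis-aligned box enclosing both boxes
    enclose = (max(r1, r2) - min(l1, l2)) * (max(b1, b2) - min(t1, t2))
    return iou - (enclose - union) / enclose


# Overlapping pair: IoU 1/7, enclosing area 9, union 7, so GIoU = 1/7 - 2/9 = -5/63
print(giou_xywh((1, 1, 2, 2), (2, 2, 2, 2)))
# Disjoint pair: IoU 0, enclosing area 36, union 2, so GIoU = -34/36 = -17/18
print(giou_xywh((0, 0, 1, 1), (5, 5, 1, 1)))
```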

Any suggestions?

Thank you



