diff --git a/mmrotate/core/post_processing/bbox_nms_rotated.py b/mmrotate/core/post_processing/bbox_nms_rotated.py index 4affc2c9c..561ed5640 100644 --- a/mmrotate/core/post_processing/bbox_nms_rotated.py +++ b/mmrotate/core/post_processing/bbox_nms_rotated.py @@ -30,6 +30,9 @@ def multiclass_nms_rotated(multi_bboxes, tuple (dets, labels, indices (optional)): tensors of shape (k, 5), \ (k), and (k). Dets are boxes with scores. Labels are 0-based. """ + # --- Windows / FP16 safety fix --- + multi_bboxes = multi_bboxes.cpu() + multi_scores = multi_scores.cpu() num_classes = multi_scores.size(1) - 1 # exclude background category if multi_bboxes.shape[1] > 5: @@ -39,7 +42,7 @@ def multiclass_nms_rotated(multi_bboxes, multi_scores.size(0), num_classes, 5) scores = multi_scores[:, :-1] - labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) + labels = torch.arange(num_classes, dtype=torch.long) labels = labels.view(1, -1).expand_as(scores) bboxes = bboxes.reshape(-1, 5) scores = scores.reshape(-1) diff --git a/mmrotate/core/visualization/image.py b/mmrotate/core/visualization/image.py index b042007d2..6e4eaf55f 100644 --- a/mmrotate/core/visualization/image.py +++ b/mmrotate/core/visualization/image.py @@ -139,6 +139,8 @@ def imshow_det_rbboxes(img, assert bboxes is not None and bboxes.shape[1] == 6 scores = bboxes[:, -1] inds = scores > score_thr + if hasattr(inds, "to"): + inds = inds.to(bboxes.device) bboxes = bboxes[inds, :] labels = labels[inds] if segms is not None: