Prerequisite
Environment
Ascend 910B
mmcv 2.2.0
Reproduces the problem - code sample
`import numpy as np
import torch
import torch_npu  # Ascend adapter; assumed to be needed so that the 'npu' device is registered
from mmcv.ops import box_iou_rotated

# Each row is one rotated box: (x_center, y_center, width, height, angle)
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6], [7.0, 7.0, 8.0, 8.0, 0.4]], dtype=np.float32)
np_boxes2 = np.asarray([[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5], [5.0, 5.0, 6.0, 7.0, 0.4]], dtype=np.float32)
# Reference IoUs for the pairwise (aligned=False) and aligned cases
np_expect_ious = np.asarray([[0.3708, 0.4351, 0.0000], [0.1104, 0.4487, 0.0424], [0.0000, 0.0000, 0.3622]], dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.3708, 0.4487, 0.3622], dtype=np.float32)

boxes1 = torch.from_numpy(np_boxes1).to('npu:0')
boxes2 = torch.from_numpy(np_boxes2).to('npu:0')

# Small case: 3 x 3 boxes, checked against the reference values
ious = box_iou_rotated(boxes1, boxes2, mode='iou', aligned=False)
print(np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4))
print(ious)

# Large case: 3 x N boxes with N = 2000; this is the call reported to fail
boxes3 = torch.tensor(np.random.randn(2000, 5), dtype=torch.float32).npu()
iou_2 = box_iou_rotated(boxes1, boxes3, mode='iou', aligned=False)
print(iou_2)
`
In the `box_iou_rotated` call above, `boxes3` has shape (N, 5). The call raises an error once N gets moderately large, for example N = 200.
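To pin down where the failure starts, the threshold could be located by bisection instead of by trying sizes by hand. The sketch below is only an illustration: `smallest_failing_n` and `fake_run` are hypothetical names, the limit of 128 is made up so the snippet runs without an NPU, and the search assumes that failures are monotonic in N (every size above the threshold also fails).

```python
from typing import Callable


def smallest_failing_n(run: Callable[[int], object], lo: int, hi: int) -> int:
    """Return the smallest N in (lo, hi] for which run(N) raises.

    Assumes run(lo) succeeds, run(hi) raises, and failures are monotonic in N.
    """
    while hi - lo > 1:
        mid = (lo + hi) // 2
        try:
            run(mid)
        except Exception:
            hi = mid  # mid fails, so the threshold is at or below mid
        else:
            lo = mid  # mid works, so the threshold is above mid
    return hi


# Stand-in for the real call so the sketch runs anywhere.
# The limit of 128 is invented for illustration, NOT a measured value.
def fake_run(n: int) -> None:
    if n > 128:
        raise RuntimeError(f"simulated failure at N={n}")


print(smallest_failing_n(fake_run, lo=3, hi=2000))
```

On the device, `run` would build `boxes3` with N rows, call `box_iou_rotated(boxes1, boxes3, mode='iou', aligned=False)`, and probably copy the result back with `.cpu()` so that an asynchronous device error surfaces inside the call.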
Reproduces the problem - command or script
The reproduction code is the sample above.
Reproduces the problem - error message
Additional information
In real use we need a larger number of boxes, i.e. a larger N as defined above. On CPU or CUDA this interface works without problems at N = 200.
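Until the operator itself handles large N, one possible stopgap is to split the second box set into small row chunks and call the operator once per chunk. This is a sketch under assumptions: `pairwise_in_chunks` and `fake_pairwise` are hypothetical names, the helper only relies on `len()` and slicing (which tensors and lists both support), and it has not been confirmed that chunking avoids the error on the NPU.

```python
from typing import Callable, Iterator


def pairwise_in_chunks(pairwise_fn: Callable, boxes1, boxes2, chunk_size: int) -> Iterator:
    """Yield pairwise_fn(boxes1, chunk) for consecutive row chunks of boxes2."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(boxes2), chunk_size):
        yield pairwise_fn(boxes1, boxes2[start:start + chunk_size])


# Stand-in pairwise function so the sketch runs without torch or mmcv:
# it returns a len(a) x len(b) nested list, mimicking an (M, N) result.
def fake_pairwise(a, b):
    return [[x + y for y in b] for x in a]


parts = list(pairwise_in_chunks(fake_pairwise, [1, 2], [10, 20, 30, 40, 50], chunk_size=2))
print(len(parts))  # number of chunks processed
```

With real tensors, the per-chunk results could be joined along the second dimension, for example `torch.cat(list(pairwise_in_chunks(lambda a, b: box_iou_rotated(a, b, mode='iou', aligned=False), boxes1, boxes3, chunk_size=100)), dim=1)`. The chunk size of 100 is arbitrary and would need to be chosen below whatever size triggers the error.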