Skip to content

Commit d5ecd10

Browse files
committed
roi/psroi fixed
1 parent 44967ef commit d5ecd10

File tree

4 files changed

+22
-0
lines changed

4 files changed

+22
-0
lines changed

src/plugins/intel_cpu/src/nodes/psroi_pooling.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ PSROIPooling::PSROIPooling(const std::shared_ptr<ov::Node>& op, const GraphConte
8585
const auto defPsroi = ov::as_type_ptr<const ov::op::v1::DeformablePSROIPooling>(op);
8686

8787
noTrans = op->get_input_size() == 2;
88+
inBatchNum = op->get_input_shape(0)[0];
8889
CPU_NODE_ASSERT(op->get_input_shape(0).size() == 4,
8990
"has first input with incorrect rank: " + std::to_string(op->get_input_shape(0).size()));
9091
CPU_NODE_ASSERT(op->get_input_shape(1).size() == 2,
@@ -619,6 +620,7 @@ void PSROIPooling::executeSpecified() {
619620
cpu_parallel->parallel_for(realRois, [&](int currentRoi) {
620621
const float* bottomRois = bottomRoisBeginning + currentRoi * 5;
621622
auto roiBatchInd = static_cast<int>(bottomRois[0]);
623+
OPENVINO_ASSERT(roiBatchInd <= inBatchNum, "required batch index > batch amount");
622624
if (getAlgorithm() == Algorithm::PSROIPoolingAverage) {
623625
executeAverage(srcData, dstData, bottomRois, currentRoi, roiBatchInd, *srcDesc, *dstDesc);
624626
} else if (getAlgorithm() == Algorithm::PSROIPoolingBilinear) {

src/plugins/intel_cpu/src/nodes/psroi_pooling.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class PSROIPooling : public Node {
4848
int nh = 0;
4949
int nw = 0;
5050

51+
int inBatchNum = 0;
52+
5153
// for Deformable PSROIPolling
5254
bool noTrans;
5355
int partSize = 1;

src/plugins/intel_cpu/src/nodes/roi_pooling.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ size_t RoiPoolingKey::hash() const {
377377
seed = hash_combine(seed, refParams.mb);
378378
seed = hash_combine(seed, refParams.c);
379379
seed = hash_combine(seed, refParams.nb_c);
380+
seed = hash_combine(seed, refParams.b_num);
380381
seed = hash_combine(seed, refParams.c_block);
381382
seed = hash_combine(seed, refParams.nb_c_blocking);
382383
seed = hash_combine(seed, refParams.ih);
@@ -546,9 +547,12 @@ void ROIPooling::prepareParams() {
546547
const auto& inDims = getParentEdgeAt(0)->getMemory().getStaticDims();
547548
const auto& outDims = getChildEdgeAt(0)->getMemory().getStaticDims();
548549

550+
const auto& featureShape = getParentEdgeAt(0)->getMemory().getStaticDims();
551+
549552
refParams.mb = outDims[0];
550553
refParams.c = rnd_up(inDims[1], refParams.c_block);
551554
refParams.nb_c = refParams.c / refParams.c_block;
555+
refParams.b_num = featureShape[0];
552556
refParams.ih = inDims[2];
553557
refParams.iw = inDims[3];
554558
refParams.oh = outDims[2];
@@ -619,6 +623,9 @@ class ROIPooling::ROIPoolingJitExecutor : public ROIPooling::ROIPoolingExecutor
619623
if (roi_batch_ind == -1) {
620624
break;
621625
}
626+
OPENVINO_ASSERT(0 <= roi_batch_ind && roi_batch_ind <= jpp.b_num,
627+
"takes incorrect roi_ind, max roi_ind = ",
628+
jpp.b_num);
622629
}
623630

624631
cpuParallel->parallel_for4d(MB, cb_work, jpp.oh, jpp.ow, [&](int n, int cbb, int oh, int ow) {
@@ -636,6 +643,9 @@ class ROIPooling::ROIPoolingJitExecutor : public ROIPooling::ROIPoolingExecutor
636643
const auto* src_roi_ptr = &src_roi[roi_off];
637644

638645
auto roi_batch_ind = static_cast<int>(src_roi_ptr[0]);
646+
OPENVINO_ASSERT(0 <= roi_batch_ind && roi_batch_ind <= jpp.b_num,
647+
"takes incorrect roi_ind, max roi_ind = ",
648+
jpp.b_num);
639649

640650
if (jpp.alg == Algorithm::ROIPoolingMax) {
641651
auto roi_start_w = static_cast<int>(round(src_roi_ptr[1] * jpp.spatial_scale));
@@ -758,6 +768,9 @@ class ROIPooling::ROIPoolingRefExecutor : public ROIPooling::ROIPoolingExecutor
758768
if (roi_batch_ind == -1) {
759769
break;
760770
}
771+
OPENVINO_ASSERT(0 <= roi_batch_ind && roi_batch_ind <= jpp.b_num,
772+
"takes incorrect roi_ind, max roi_ind = ",
773+
jpp.b_num);
761774
}
762775

763776
cpuParallel->parallel_for4d(MB, cb_work, jpp.oh, jpp.ow, [&](int n, int cbb, int oh, int ow) {
@@ -780,6 +793,9 @@ class ROIPooling::ROIPoolingRefExecutor : public ROIPooling::ROIPoolingExecutor
780793
const auto* src_roi_ptr = &src_roi[roi_off];
781794

782795
auto roi_batch_ind = static_cast<int>(src_roi_ptr[0]);
796+
OPENVINO_ASSERT(0 <= roi_batch_ind && roi_batch_ind <= jpp.b_num,
797+
"takes incorrect roi_ind, max roi_ind = ",
798+
jpp.b_num);
783799

784800
if (jpp.alg == Algorithm::ROIPoolingMax) {
785801
auto roi_start_w = static_cast<int>(round(src_roi_ptr[1] * jpp.spatial_scale));

src/plugins/intel_cpu/src/nodes/roi_pooling.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ struct jit_roi_pooling_params {
2727

2828
int c_block, nb_c, nb_c_blocking;
2929

30+
int b_num;
31+
3032
double spatial_scale;
3133
int pooled_h;
3234
int pooled_w;

0 commit comments

Comments
 (0)