Skip to content

Commit 3b9a8a6

Browse files
committed
Rewrite xGetSADwMask_neon to improve performance
The subtraction and absolute value can be replaced with a single absolute difference instruction, and we can use a pair of accumulators to avoid a bottleneck on the accumulating multiply-add instruction latency. When benchmarking with LLVM-20 on recent Arm Neoverse micro-architectures, this reduces the time spent in xGetSADwMask_neon by ~19%. Additionally, enable this kernel even when TARGET_SIMD_X86 is disabled, since it does not rely on the x86 intrinsics code.
1 parent d5e3270 commit 3b9a8a6

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

source/Lib/CommonLib/arm/neon/RdCost_neon.cpp

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -69,13 +69,6 @@ POSSIBILITY OF SUCH DAMAGE.
6969
namespace vvenc
7070
{
7171

72-
static inline int32x4_t neon_madd_16( int16x8_t a, int16x8_t b )
73-
{
74-
int32x4_t c = vmull_s16( vget_low_s16( a ), vget_low_s16( b ) );
75-
int32x4_t d = vmull_s16( vget_high_s16( a ), vget_high_s16( b ) );
76-
return pairwise_add_s32x4( c, d );
77-
}
78-
7972
#if ENABLE_SIMD_OPT_DIST && defined( TARGET_SIMD_ARM )
8073

8174
// The xGetHADs_neon functions depend on the SIMDe kernels being enabled
@@ -1119,46 +1112,55 @@ static inline int16x8_t reverse_vector_s16( int16x8_t x )
11191112
Distortion xGetSADwMask_neon( const DistParam& rcDtParam )
11201113
{
11211114
if (rcDtParam.org.width < 4 || rcDtParam.bitDepth > 10 || rcDtParam.applyWeight)
1115+
{
11221116
return RdCost::xGetSADwMask(rcDtParam);
1117+
}
11231118

11241119
const short *src1 = (const short *) rcDtParam.org.buf;
11251120
const short *src2 = (const short *) rcDtParam.cur.buf;
11261121
const short *weightMask = (const short *) rcDtParam.mask;
11271122
int rows = rcDtParam.org.height;
11281123
int cols = rcDtParam.org.width;
11291124
int subShift = rcDtParam.subShift;
1130-
int subStep = (1 << subShift);
1125+
int subStep = 1 << subShift;
11311126
const int strideSrc1 = rcDtParam.org.stride * subStep;
11321127
const int strideSrc2 = rcDtParam.cur.stride * subStep;
11331128
const int strideMask = rcDtParam.maskStride * subStep;
11341129

1135-
Distortion sum = 0;
1130+
int32x4_t sum0 = vdupq_n_s32( 0 );
1131+
int32x4_t sum1 = vdupq_n_s32( 0 );
11361132

1137-
int32x4_t vsum32 = vdupq_n_s32( 0 );
1138-
1139-
for (int y = 0; y < rows; y += subStep)
1133+
do
11401134
{
1141-
for (int x = 0; x < cols; x += 8)
1135+
int x = 0;
1136+
do
11421137
{
1143-
int16x8_t vsrc1 = vld1q_s16( ( const int16_t* )(&src1[x] ) );
1144-
int16x8_t vsrc2 = vld1q_s16( ( const int16_t* )(&src2[x] ) );
1138+
int16x8_t vsrc1 = vld1q_s16( src1 + x );
1139+
int16x8_t vsrc2 = vld1q_s16( src2 + x );
11451140
int16x8_t vmask;
11461141
if (rcDtParam.stepX == -1)
11471142
{
1148-
vmask = vld1q_s16( ( const int16_t* )( ( &weightMask[ x ] ) - ( x << 1 ) - ( 8 - 1 ) ) );
1143+
vmask = vld1q_s16( weightMask - x - 7 );
11491144
vmask = reverse_vector_s16( vmask );
11501145
}
11511146
else
11521147
{
1153-
vmask = vld1q_s16( ( const int16_t* ) (&weightMask[x]));
1148+
vmask = vld1q_s16( weightMask + x );
11541149
}
1155-
vsum32 = vaddq_s32(vsum32, neon_madd_16(vmask, vabsq_s16(vsubq_s16(vsrc1, vsrc2))));
1156-
}
1150+
int16x8_t diff = vabdq_s16( vsrc1, vsrc2 );
1151+
sum0 = vmlal_s16( sum0, vget_low_s16( diff ), vget_low_s16( vmask ) );
1152+
sum1 = vmlal_s16( sum1, vget_high_s16( diff ), vget_high_s16( vmask ) );
1153+
1154+
x += 8;
1155+
} while( x != cols );
1156+
11571157
src1 += strideSrc1;
11581158
src2 += strideSrc2;
11591159
weightMask += strideMask;
1160-
}
1161-
sum = horizontal_add_s32x4( vsum32 );
1160+
rows -= subStep;
1161+
} while( rows != 0 );
1162+
1163+
Distortion sum = horizontal_add_s32x4( vaddq_s32( sum0, sum1 ) );
11621164
sum <<= subShift;
11631165
return sum >> DISTORTION_PRECISION_ADJUSTMENT(rcDtParam.bitDepth);
11641166
}
@@ -1397,6 +1399,8 @@ void RdCost::_initRdCostARM<NEON>()
13971399
{
13981400
m_afpDistortFuncX5[1] = xGetSADX5_16xN_neon;
13991401

1402+
m_afpDistortFunc[0][DF_SAD_WITH_MASK] = xGetSADwMask_neon;
1403+
14001404
#if defined( TARGET_SIMD_X86 )
14011405
m_afpDistortFunc[0][DF_HAD_2SAD ] = xGetHAD2SADs_neon;
14021406

@@ -1424,9 +1428,6 @@ void RdCost::_initRdCostARM<NEON>()
14241428
m_afpDistortFunc[0][DF_SAD32 ] = xGetSAD_NxN_neon<32>;
14251429
m_afpDistortFunc[0][DF_SAD64 ] = xGetSAD_NxN_neon<64>;
14261430
m_afpDistortFunc[0][DF_SAD128] = xGetSAD_NxN_neon<128>;
1427-
1428-
m_afpDistortFunc[0][DF_SAD_WITH_MASK] = xGetSADwMask_neon;
1429-
14301431
#endif // defined( TARGET_SIMD_X86 )
14311432

14321433
m_wtdPredPtr[0] = lumaWeightedSSE_neon<0>;

0 commit comments

Comments (0)