Skip to content

Commit baf3f45

Browse files
authored
Merge pull request #626 from salomethirot-arm/getSAD-upstream
Optimize the Neon implementation of xGetSAD_NxN
2 parents af8bfb8 + 96e442a commit baf3f45

File tree

2 files changed: 92 additions and 92 deletions.

source/Lib/CommonLib/arm/neon/RdCost_neon.cpp

Lines changed: 80 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -897,14 +897,6 @@ Distortion xGetHADs_neon( const DistParam &rcDtParam )
897897

898898
return uiSum >> DISTORTION_PRECISION_ADJUSTMENT(rcDtParam.bitDepth);
899899
}
900-
901-
Distortion xGetHAD2SADs_neon( const DistParam &rcDtParam )
902-
{
903-
Distortion distHad = xGetHADs_neon<false>( rcDtParam );
904-
Distortion distSad = RdCost::xGetSAD_SIMD<SIMD_EVERYWHERE_EXTENSION_LEVEL>( rcDtParam );
905-
906-
return std::min( distHad, 2*distSad);
907-
}
908900
#endif // defined( TARGET_SIMD_X86 )
909901

910902
template<bool isCalCentrePos>
@@ -996,106 +988,103 @@ void xGetSADX5_16xN_neon(const DistParam& rcDtParam, Distortion* cost, bool isCa
996988
xGetSADX5_16xN_neon_impl<false>( rcDtParam, cost );
997989
}
998990

999-
template< int iWidth >
1000-
Distortion xGetSAD_NxN_neon( const DistParam &rcDtParam )
991+
static inline Distortion xGetSAD_generic_neon( const DistParam& rcDtParam, const int width )
1001992
{
993+
if( width < 4 )
994+
{
995+
return RdCost::xGetSAD( rcDtParam );
996+
}
1002997

1003-
const short* pSrc1 = (const short*)rcDtParam.org.buf;
1004-
const short* pSrc2 = (const short*)rcDtParam.cur.buf;
1005-
int iRows = rcDtParam.org.height;
1006-
int iSubShift = rcDtParam.subShift;
1007-
int iSubStep = ( 1 << iSubShift );
1008-
const int iStrideSrc1 = rcDtParam.org.stride * iSubStep;
1009-
const int iStrideSrc2 = rcDtParam.cur.stride * iSubStep;
1010-
1011-
uint32_t uiSum = 0;
1012-
int16x8_t vzero_16 = vdupq_n_s16(0);
998+
const int16_t* src1 = rcDtParam.org.buf;
999+
const int16_t* src2 = rcDtParam.cur.buf;
1000+
int height = rcDtParam.org.height;
1001+
int subShift = rcDtParam.subShift;
1002+
int subStep = 1 << subShift;
1003+
const int strideSrc1 = rcDtParam.org.stride * subStep;
1004+
const int strideSrc2 = rcDtParam.cur.stride * subStep;
10131005

1014-
if( iWidth == 4 )
1006+
uint32x4_t sum_u32[2] = { vdupq_n_u32( 0 ), vdupq_n_u32( 0 ) };
1007+
Distortion sum = 0;
1008+
do
10151009
{
1016-
if( iRows == 4 && iSubShift == 0 )
1010+
int w = width;
1011+
1012+
const int16_t* src1_ptr = src1;
1013+
const int16_t* src2_ptr = src2;
1014+
1015+
while( w >= 16 )
10171016
{
1018-
int16x8_t vsrc1 = vcombine_s16( vld1_s16( ( const int16_t* )pSrc1 ), vld1_s16( ( const int16_t* )( &pSrc1[iStrideSrc1] ) ) );
1019-
int16x8_t vsrc2 = vcombine_s16( vld1_s16( ( const int16_t* )pSrc2 ), vld1_s16( ( const int16_t* )( &pSrc2[iStrideSrc2] ) ) );
1020-
int32x4_t vsum =
1021-
vmovl_s16( vget_low_s16( pairwise_add_s16x8( vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) ), vzero_16 ) ) );
1022-
vsrc1 = vcombine_s16( vld1_s16( ( const int16_t* )( &pSrc1[2 * iStrideSrc1] ) ), vld1_s16( ( const int16_t* )( &pSrc1[3 * iStrideSrc1] ) ) );
1023-
vsrc2 = vcombine_s16( vld1_s16( ( const int16_t* )( &pSrc2[2 * iStrideSrc2] ) ), vld1_s16( ( const int16_t* )( &pSrc2[3 * iStrideSrc2] ) ) );
1024-
vsum = vaddq_s32(
1025-
vsum, vmovl_s16( vget_low_s16( pairwise_add_s16x8( vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) ), vzero_16 ) ) ) );
1026-
uiSum = horizontal_add_s32x4( vsum );
1017+
const int16x8_t s1_lo = vld1q_s16( src1_ptr );
1018+
const int16x8_t s1_hi = vld1q_s16( src1_ptr + 8 );
1019+
const int16x8_t s2_lo = vld1q_s16( src2_ptr );
1020+
const int16x8_t s2_hi = vld1q_s16( src2_ptr + 8 );
1021+
1022+
const uint16x8_t abs_lo = vreinterpretq_u16_s16( vabdq_s16( s1_lo, s2_lo ) );
1023+
const uint16x8_t abs_hi = vreinterpretq_u16_s16( vabdq_s16( s1_hi, s2_hi ) );
1024+
1025+
sum_u32[0] = vpadalq_u16( sum_u32[0], abs_lo );
1026+
sum_u32[1] = vpadalq_u16( sum_u32[1], abs_hi );
1027+
1028+
src1_ptr += 16;
1029+
src2_ptr += 16;
1030+
w -= 16;
10271031
}
1028-
else
1032+
1033+
if( w >= 8 )
10291034
{
1030-
int32x4_t vsum32 = vdupq_n_s32(0);
1031-
for( int iY = 0; iY < iRows; iY += iSubStep )
1032-
{
1033-
int32x4_t vsrc1 = vmovl_s16( vld1_s16( ( const int16_t* )pSrc1 ) );
1034-
int32x4_t vsrc2 = vmovl_s16( vld1_s16( ( const int16_t* )pSrc2 ) );
1035-
vsum32 = vaddq_s32( vsum32, vabsq_s32( vsubq_s32( vsrc1, vsrc2 ) ) );
1035+
const int16x8_t s1 = vld1q_s16( src1_ptr );
1036+
const int16x8_t s2 = vld1q_s16( src2_ptr );
10361037

1037-
pSrc1 += iStrideSrc1;
1038-
pSrc2 += iStrideSrc2;
1039-
}
1040-
uiSum = horizontal_add_s32x4( vsum32 );
1038+
const uint16x8_t abs = vreinterpretq_u16_s16( vabdq_s16( s1, s2 ) );
1039+
sum_u32[0] = vpadalq_u16( sum_u32[0], abs );
1040+
1041+
src1_ptr += 8;
1042+
src2_ptr += 8;
1043+
w -= 8;
10411044
}
1042-
}
1043-
else
1044-
{
1045-
static constexpr bool earlyExitAllowed = iWidth >= 64;
1046-
int32x4_t vsum32 = vdupq_n_s32( 0 );
1047-
int checkExit = 3;
10481045

1049-
for( int iY = 0; iY < iRows; iY+=iSubStep )
1046+
if( w >= 4 )
10501047
{
1051-
int16x8_t vsrc1 = vld1q_s16( ( const int16_t* )( pSrc1 ) );
1052-
int16x8_t vsrc2 = vld1q_s16( ( const int16_t* )( pSrc2 ) );
1053-
int16x8_t vsum16 = vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) );
1048+
const int16x4_t s1 = vld1_s16( src1_ptr );
1049+
const int16x4_t s2 = vld1_s16( src2_ptr );
10541050

1055-
if( iWidth >= 16 )
1056-
{
1057-
vsrc1 = vld1q_s16( ( const int16_t* )( &pSrc1[8] ) );
1058-
vsrc2 = vld1q_s16( ( const int16_t* )( &pSrc2[8] ) );
1059-
vsum16 = vaddq_s16( vsum16, vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) ) );
1051+
const uint16x4_t abs = vreinterpret_u16_s16( vabd_s16( s1, s2 ) );
1052+
sum_u32[0] = vaddw_u16( sum_u32[0], abs );
10601053

1061-
for( int iX = 16; iX < iWidth; iX += 16 )
1062-
{
1063-
vsrc1 = vld1q_s16( ( const int16_t* )( &pSrc1[iX] ) );
1064-
vsrc2 = vld1q_s16( ( const int16_t* )( &pSrc2[iX] ) );
1065-
vsum16 = vaddq_s16( vsum16, vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) ) );
1054+
src1_ptr += 4;
1055+
src2_ptr += 4;
1056+
w -= 4;
1057+
}
10661058

1067-
vsrc1 = vld1q_s16( ( const int16_t* )( &pSrc1[iX + 8] ) );
1068-
vsrc2 = vld1q_s16( ( const int16_t* )( &pSrc2[iX + 8] ) );
1069-
vsum16 = vaddq_s16( vsum16, vabsq_s16( vsubq_s16( vsrc1, vsrc2 ) ) );
1070-
}
1071-
}
1059+
while( w != 0 )
1060+
{
1061+
sum += abs( src1_ptr[w - 1] - src2_ptr[w - 1] );
10721062

1073-
int32x4_t vsumtemp = vpaddlq_s16( vsum16);
1063+
w--;
1064+
}
10741065

1075-
if( earlyExitAllowed ) vsum32 = pairwise_add_s32x4( vsum32, vsumtemp );
1076-
else vsum32 = vaddq_s32 ( vsum32, vsumtemp );
1066+
src1 += strideSrc1;
1067+
src2 += strideSrc2;
1068+
height -= subStep;
1069+
} while( height != 0 );
10771070

1078-
pSrc1 += iStrideSrc1;
1079-
pSrc2 += iStrideSrc2;
1071+
sum += horizontal_add_u32x4( vaddq_u32( sum_u32[0], sum_u32[1] ) );
1072+
sum <<= subShift;
1073+
return sum >> DISTORTION_PRECISION_ADJUSTMENT( rcDtParam.bitDepth );
1074+
}
10801075

1081-
if( earlyExitAllowed && checkExit == 0 )
1082-
{
1083-
Distortion distTemp = vgetq_lane_s32(vsum32, 0);
1084-
distTemp <<= iSubShift;
1085-
distTemp >>= DISTORTION_PRECISION_ADJUSTMENT( rcDtParam.bitDepth );
1086-
if( distTemp > rcDtParam.maximumDistortionForEarlyExit ) return distTemp;
1087-
checkExit = 3;
1088-
}
1089-
else if( earlyExitAllowed )
1090-
{
1091-
checkExit--;
1092-
}
1093-
}
1094-
uiSum = horizontal_add_s32x4( vsum32 );
1095-
}
1076+
Distortion xGetHAD2SADs_neon( const DistParam& rcDtParam )
1077+
{
1078+
Distortion distHad = xGetHADs_neon<false>( rcDtParam );
1079+
Distortion distSad = xGetSAD_generic_neon( rcDtParam, rcDtParam.org.width );
10961080

1097-
uiSum <<= iSubShift;
1098-
return uiSum >> DISTORTION_PRECISION_ADJUSTMENT(rcDtParam.bitDepth);
1081+
return std::min( distHad, 2 * distSad );
1082+
}
1083+
1084+
template<int iWidth>
1085+
Distortion xGetSAD_NxN_neon( const DistParam& rcDtParam )
1086+
{
1087+
return xGetSAD_generic_neon( rcDtParam, iWidth );
10991088
}
11001089

11011090
Distortion xGetSADwMask_neon( const DistParam& rcDtParam )
@@ -1389,10 +1378,9 @@ void RdCost::_initRdCostARM<NEON>()
13891378
m_afpDistortFuncX5[1] = xGetSADX5_16xN_neon;
13901379

13911380
m_afpDistortFunc[0][DF_SAD_WITH_MASK] = xGetSADwMask_neon;
1392-
1393-
#if defined( TARGET_SIMD_X86 )
13941381
m_afpDistortFunc[0][DF_HAD_2SAD ] = xGetHAD2SADs_neon;
13951382

1383+
#if defined( TARGET_SIMD_X86 )
13961384
m_afpDistortFunc[0][DF_HAD] = xGetHADs_neon<false>;
13971385
m_afpDistortFunc[0][DF_HAD2] = xGetHADs_neon<false>;
13981386
m_afpDistortFunc[0][DF_HAD4] = xGetHADs_neon<false>;
@@ -1410,14 +1398,14 @@ void RdCost::_initRdCostARM<NEON>()
14101398
m_afpDistortFunc[0][DF_HAD32_fast] = xGetHADs_neon<true>;
14111399
m_afpDistortFunc[0][DF_HAD64_fast] = xGetHADs_neon<true>;
14121400
m_afpDistortFunc[0][DF_HAD128_fast] = xGetHADs_neon<true>;
1401+
#endif // defined( TARGET_SIMD_X86 )
14131402

14141403
m_afpDistortFunc[0][DF_SAD4 ] = xGetSAD_NxN_neon<4>;
14151404
m_afpDistortFunc[0][DF_SAD8 ] = xGetSAD_NxN_neon<8>;
14161405
m_afpDistortFunc[0][DF_SAD16 ] = xGetSAD_NxN_neon<16>;
14171406
m_afpDistortFunc[0][DF_SAD32 ] = xGetSAD_NxN_neon<32>;
14181407
m_afpDistortFunc[0][DF_SAD64 ] = xGetSAD_NxN_neon<64>;
14191408
m_afpDistortFunc[0][DF_SAD128] = xGetSAD_NxN_neon<128>;
1420-
#endif // defined( TARGET_SIMD_X86 )
14211409

14221410
m_wtdPredPtr[0] = lumaWeightedSSE_neon<0>;
14231411
m_wtdPredPtr[1] = lumaWeightedSSE_neon<1>;

source/Lib/CommonLib/arm/neon/sum_neon.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,18 @@ static inline int16_t horizontal_add_s16x8( const int16x8_t a )
6767
#endif
6868
}
6969

70+
static inline uint32_t horizontal_add_u32x4( const uint32x4_t a )
71+
{
72+
#if REAL_TARGET_AARCH64
73+
return vaddvq_u32( a );
74+
#else
75+
const uint64x2_t b = vpaddlq_u32( a );
76+
const uint32x2_t c =
77+
vadd_u32( vreinterpret_u32_u64( vget_low_u64( b ) ), vreinterpret_u32_u64( vget_high_u64( b ) ) );
78+
return vget_lane_u32( c, 0 );
79+
#endif
80+
}
81+
7082
static inline int horizontal_add_s32x4( const int32x4_t a )
7183
{
7284
#if REAL_TARGET_AARCH64

0 commit comments

Comments
 (0)