Skip to content

Commit 667bb8c

Browse files
authored
Merge pull request #620 from athulya-arm/athulya-arm/motionErrorLumaFrac-sve
Arm: Add SVE implementation of motionErrorLumaFrac_loRes
2 parents 629d4a0 + b85155a commit 667bb8c

File tree

4 files changed

+236
-95
lines changed

4 files changed

+236
-95
lines changed

source/Lib/CommonLib/arm/MCTF_neon.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,97 @@ namespace vvenc
5656

5757
#if defined( TARGET_SIMD_ARM ) && ENABLE_SIMD_OPT_MCTF
5858

59+
// Classifies a 4-tap filter by the shape of its coefficients so that a
// specialised kernel can be chosen for it.
enum class FilterCoeffType4
{
  SkewLeft      = 1, // Last tap is zero: coeff[3] == 0.
  SkewRight     = 2, // First tap is zero: coeff[0] == 0.
  FullSymmetric = 3, // Mirrored taps: coeff[0] == coeff[3] and coeff[1] == coeff[2].
  Generic       = 4  // Anything else; no special structure is assumed.
};
66+
67+
static inline auto selectFilterType4( const int16_t* coeff )
68+
{
69+
auto coeffPtr = reinterpret_cast<const int16_t ( * )[4]>( coeff );
70+
auto idx = coeffPtr - MCTF::m_interpolationFilter4;
71+
72+
switch( idx )
73+
{
74+
case 1:
75+
return FilterCoeffType4::SkewLeft;
76+
case 8:
77+
return FilterCoeffType4::FullSymmetric;
78+
case 15:
79+
return FilterCoeffType4::SkewRight;
80+
default:
81+
return FilterCoeffType4::Generic;
82+
}
83+
}
84+
85+
template<FilterCoeffType4 Type>
86+
static inline int16x8_t motionErrorLumaFrac_loRes1D_neon( const int16x8_t* src, const int16x4_t filter,
87+
const Pel maxSampleValue )
88+
{
89+
constexpr int filterBits = 6 - 1; // Filter weight is 64 >> 1.
90+
switch( Type )
91+
{
92+
case FilterCoeffType4::SkewLeft:
93+
{
94+
//{ -2, 62, 4, 0 }
95+
// => -2s0 + 62s1 + 4s2
96+
// => -2s0 + 64s1 - 2s1 + 4s2
97+
// => 2( 32s1 + 2s2 - ( s0 + s1 ) )
98+
int16x8_t sum01 = vaddq_s16( src[0], src[1] ); // Input is 10-bit.
99+
int16x8_t diff2_01 = vsubq_s16( vshlq_n_s16( src[2], 1 ), sum01 );
100+
int16x8_t sum = vhaddq_s16( diff2_01, vshlq_n_s16( src[1], 5 ) ); // 16-bit.
101+
102+
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
103+
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
104+
}
105+
case FilterCoeffType4::SkewRight:
106+
{
107+
//{ 0, 4, 62, -2 }
108+
// => 4s1 + 62s2 - 2s3
109+
// => 4s1 + 64s2 - 2s2 - 2s3
110+
// => 4s1 + 64s2 -2( s2 + s3 )
111+
// => 2( 2s1 + 32s2 - ( s2 + s3 ))
112+
int16x8_t sum23 = vaddq_s16( src[2], src[3] ); // Input is 10-bit.
113+
int16x8_t diff1_23 = vsubq_s16( vshlq_n_s16( src[1], 1 ), sum23 );
114+
int16x8_t sum = vhaddq_s16( diff1_23, vshlq_n_s16( src[2], 5 ) ); // 16-bit.
115+
116+
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
117+
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
118+
}
119+
case FilterCoeffType4::FullSymmetric:
120+
{
121+
//{ -4, 36, 36, -4 }
122+
// => -4s0 + 36s1 + 36s2 - 4s3
123+
// => -4s0 + 32s1 + 4s1 + 32s2 + 4s2 - 4s3
124+
// => 4( s1 + s2 - s0 - s3 ) + 32( s1 + s2 )
125+
// => 4( ( s1 + s2 ) - ( s0 + s3 ) + 8( s1 + s2 ) )
126+
int16x8_t sum03 = vaddq_s16( src[0], src[3] ); // Input is 10-bit.
127+
int16x8_t sum12 = vaddq_s16( src[1], src[2] );
128+
int16x8_t diff12_03 = vsubq_s16( sum12, sum03 );
129+
int16x8_t sum = vhaddq_s16( diff12_03, vshlq_n_s16( sum12, 3 ) ); // 16-bit.
130+
131+
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 2 ), vdupq_n_s16( 0 ) );
132+
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
133+
}
134+
case FilterCoeffType4::Generic:
135+
default:
136+
{
137+
int16x8_t sum01 = vmulq_lane_s16( src[0], filter, 0 );
138+
sum01 = vmlaq_lane_s16( sum01, src[1], filter, 1 );
139+
int16x8_t sum23 = vmulq_lane_s16( src[2], filter, 2 );
140+
sum23 = vmlaq_lane_s16( sum23, src[3], filter, 3 );
141+
142+
int16x8_t sum = vhaddq_s16( sum01, sum23 );
143+
144+
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
145+
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
146+
}
147+
}
148+
}
149+
59150
// Six-entry constant table. Entry k equals n^2 * ( n^2 - 1 ) / 12 with n = 2^k
// (0, 1, 20, 336, 5440, 87296 for k = 0 .. 5).
// NOTE(review): the code that indexes this table is outside this excerpt;
// confirm how the index is derived before changing the table's length.
static const int32_t xSzm[6] = { 0, 1, 20, 336, 5440, 87296 };
60151

61152
static inline void applyPlanarDeblockingCorrection_common( Pel* dstPel, const ptrdiff_t dstStride, const int32_t x1yzm,
File renamed without changes.

source/Lib/CommonLib/arm/neon/MCTF_neon.cpp

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -64,101 +64,6 @@ POSSIBILITY OF SUCH DAMAGE.
6464
namespace vvenc
6565
{
6666

67-
// Shape categories for a 4-tap filter; each category can be served by a
// dedicated kernel.
enum class FilterCoeffType4
{
  SkewLeft      = 1, // coeff[3] is zero.
  SkewRight     = 2, // coeff[0] is zero.
  FullSymmetric = 3, // coeff[0] equals coeff[3] and coeff[1] equals coeff[2].
  Generic       = 4  // Every other filter.
};
74-
75-
static inline auto selectFilterType4( const int16_t* coeff )
76-
{
77-
auto coeffPtr = reinterpret_cast<const int16_t ( * )[4]>( coeff );
78-
auto idx = coeffPtr - MCTF::m_interpolationFilter4;
79-
80-
switch( idx )
81-
{
82-
case 1:
83-
return FilterCoeffType4::SkewLeft;
84-
case 8:
85-
return FilterCoeffType4::FullSymmetric;
86-
case 15:
87-
return FilterCoeffType4::SkewRight;
88-
default:
89-
return FilterCoeffType4::Generic;
90-
}
91-
}
92-
93-
template<FilterCoeffType4 Type>
94-
static inline int16x8_t motionErrorLumaFrac_loRes1D_neon( const int16x8_t* src, const int16x4_t filter,
95-
const Pel maxSampleValue )
96-
{
97-
constexpr int filterBits = 6 - 1; // Filter weight is 64 >> 1.
98-
switch( Type )
99-
{
100-
case FilterCoeffType4::SkewLeft:
101-
{
102-
//{ -2, 62, 4, 0 }
103-
// => -2s0 + 62s1 + 4s2
104-
// => -2s0 + 64s1 - 2s1 + 4s2
105-
// => 2( 32s1 + 2s2 - ( s0 + s1 ) )
106-
int16x8_t sum01 = vaddq_s16( src[0], src[1] ); // Input is 10-bit.
107-
int16x8_t diff2_01 = vsubq_s16( vshlq_n_s16( src[2], 1 ), sum01 );
108-
int16x8_t sum = vhaddq_s16( diff2_01, vshlq_n_s16( src[1], 5 ) ); // 16-bit.
109-
110-
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
111-
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
112-
}
113-
break;
114-
case FilterCoeffType4::SkewRight:
115-
{
116-
//{ 0, 4, 62, -2 }
117-
// => 4s1 + 62s2 - 2s3
118-
// => 4s1 + 64s2 - 2s2 - 2s3
119-
// => 4s1 + 64s2 -2( s2 + s3 )
120-
// => 2( 2s1 + 32s2 - ( s2 + s3 ))
121-
int16x8_t sum23 = vaddq_s16( src[2], src[3] ); // Input is 10-bit.
122-
int16x8_t diff1_23 = vsubq_s16( vshlq_n_s16( src[1], 1 ), sum23 );
123-
int16x8_t sum = vhaddq_s16( diff1_23, vshlq_n_s16( src[2], 5 ) ); // 16-bit.
124-
125-
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
126-
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
127-
}
128-
break;
129-
case FilterCoeffType4::FullSymmetric:
130-
{
131-
//{ -4, 36, 36, -4 }
132-
// => -4s0 + 36s1 + 36s2 - 4s3
133-
// => -4s0 + 32s1 + 4s1 + 32s2 + 4s2 - 4s3
134-
// => 4( s1 + s2 - s0 - s3 ) + 32( s1 + s2 )
135-
// => 4( ( s1 + s2 ) - ( s0 + s3 ) + 8( s1 + s2 ) )
136-
int16x8_t sum03 = vaddq_s16( src[0], src[3] ); // Input is 10-bit.
137-
int16x8_t sum12 = vaddq_s16( src[1], src[2] );
138-
int16x8_t diff12_03 = vsubq_s16( sum12, sum03 );
139-
int16x8_t sum = vhaddq_s16( diff12_03, vshlq_n_s16( sum12, 3 ) ); // 16-bit.
140-
141-
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 2 ), vdupq_n_s16( 0 ) );
142-
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
143-
}
144-
break;
145-
case FilterCoeffType4::Generic:
146-
default:
147-
{
148-
int16x8_t sum01 = vmulq_lane_s16( src[0], filter, 0 );
149-
sum01 = vmlaq_lane_s16( sum01, src[1], filter, 1 );
150-
int16x8_t sum23 = vmulq_lane_s16( src[2], filter, 2 );
151-
sum23 = vmlaq_lane_s16( sum23, src[3], filter, 3 );
152-
153-
int16x8_t sum = vhaddq_s16( sum01, sum23 );
154-
155-
sum = vmaxq_s16( vrshrq_n_s16( sum, filterBits - 1 ), vdupq_n_s16( 0 ) );
156-
return vminq_s16( sum, vdupq_n_s16( maxSampleValue ) );
157-
}
158-
break;
159-
}
160-
}
161-
16267
template<FilterCoeffType4 xType, FilterCoeffType4 yType>
16368
static inline int motionErrorLumaFrac_loRes2D_neon( const Pel* org, const ptrdiff_t origStride, const Pel* buf,
16469
const ptrdiff_t buffStride, int w, int h, const int16_t* xFilter,

source/Lib/CommonLib/arm/sve/MCTF_sve.cpp

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ POSSIBILITY OF SUCH DAMAGE.
6060
#if defined( TARGET_SIMD_ARM ) && ENABLE_SIMD_OPT_MCTF
6161

6262
#include "MCTF_neon.h"
63+
#include "mem_neon.h"
6364
#include "neon_sve_bridge.h"
6465
#include <arm_neon.h>
6566
#include <arm_sve.h>
@@ -118,6 +119,149 @@ int motionErrorLumaInt_sve( const Pel* org, const ptrdiff_t origStride, const Pe
118119
return error;
119120
}
120121

122+
template<FilterCoeffType4 xType, FilterCoeffType4 yType>
123+
static inline int motionErrorLumaFrac_loRes2D_sve( const Pel* org, const ptrdiff_t origStride, const Pel* buf,
124+
const ptrdiff_t buffStride, int w, int h, const int16_t* xFilter,
125+
const int16_t* yFilter, const int bitDepth, const int besterror )
126+
{
127+
const Pel maxSampleValue = ( 1 << bitDepth ) - 1;
128+
129+
CHECKD( w % 8 != 0, "Width must be multiple of 8!" );
130+
CHECKD( h % 4 != 0, "Height must be multiple of 4!" );
131+
132+
const int16x4_t xf = vrshr_n_s16( vld1_s16( xFilter ), 1 );
133+
const int16x4_t yf = vrshr_n_s16( vld1_s16( yFilter ), 1 );
134+
135+
constexpr int numFilterTaps = 4;
136+
int16x8_t h_src[numFilterTaps];
137+
int16x8_t v_src[numFilterTaps + 3]; // 3 extra elements are needed because the height loop is unrolled 4 times.
138+
139+
int error = 0;
140+
141+
do
142+
{
143+
load_s16_16x8x4( buf - 1 * buffStride - 1, 1, h_src );
144+
v_src[0] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
145+
146+
load_s16_16x8x4( buf + 0 * buffStride - 1, 1, h_src );
147+
v_src[1] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
148+
149+
load_s16_16x8x4( buf + 1 * buffStride - 1, 1, h_src );
150+
v_src[2] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
151+
152+
const Pel* rowStart = buf + 2 * buffStride - 1;
153+
const Pel* origRow = org;
154+
155+
int64x2_t diffSq0 = vdupq_n_s64( 0 );
156+
int64x2_t diffSq1 = vdupq_n_s64( 0 );
157+
158+
int y = h;
159+
do
160+
{
161+
load_s16_16x8x4( rowStart + 0 * buffStride, 1, h_src );
162+
v_src[3] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
163+
164+
load_s16_16x8x4( rowStart + 1 * buffStride, 1, h_src );
165+
v_src[4] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
166+
167+
load_s16_16x8x4( rowStart + 2 * buffStride, 1, h_src );
168+
v_src[5] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
169+
170+
load_s16_16x8x4( rowStart + 3 * buffStride, 1, h_src );
171+
v_src[6] = motionErrorLumaFrac_loRes1D_neon<xType>( h_src, xf, maxSampleValue );
172+
173+
int16x8_t ysum0 = motionErrorLumaFrac_loRes1D_neon<yType>( &v_src[0], yf, maxSampleValue );
174+
int16x8_t ysum1 = motionErrorLumaFrac_loRes1D_neon<yType>( &v_src[1], yf, maxSampleValue );
175+
int16x8_t ysum2 = motionErrorLumaFrac_loRes1D_neon<yType>( &v_src[2], yf, maxSampleValue );
176+
int16x8_t ysum3 = motionErrorLumaFrac_loRes1D_neon<yType>( &v_src[3], yf, maxSampleValue );
177+
178+
int16x8_t orig0 = vld1q_s16( origRow + 0 * origStride );
179+
int16x8_t orig1 = vld1q_s16( origRow + 1 * origStride );
180+
int16x8_t orig2 = vld1q_s16( origRow + 2 * origStride );
181+
int16x8_t orig3 = vld1q_s16( origRow + 3 * origStride );
182+
183+
int16x8_t diff0 = vabdq_s16( ysum0, orig0 );
184+
int16x8_t diff1 = vabdq_s16( ysum1, orig1 );
185+
int16x8_t diff2 = vabdq_s16( ysum2, orig2 );
186+
int16x8_t diff3 = vabdq_s16( ysum3, orig3 );
187+
188+
diffSq0 = vvenc_sdotq_s16( diffSq0, diff0, diff0 );
189+
diffSq0 = vvenc_sdotq_s16( diffSq0, diff1, diff1 );
190+
diffSq1 = vvenc_sdotq_s16( diffSq1, diff2, diff2 );
191+
diffSq1 = vvenc_sdotq_s16( diffSq1, diff3, diff3 );
192+
193+
v_src[0] = v_src[4];
194+
v_src[1] = v_src[5];
195+
v_src[2] = v_src[6];
196+
197+
rowStart += 4 * buffStride;
198+
origRow += 4 * origStride;
199+
y -= 4;
200+
} while( y != 0 );
201+
202+
int64x2_t diffSq = vaddq_s64( diffSq0, diffSq1 );
203+
error += ( int32_t )vaddvq_s64( diffSq );
204+
if( error > besterror )
205+
{
206+
return error;
207+
}
208+
209+
buf += 8;
210+
org += 8;
211+
w -= 8;
212+
} while( w != 0 );
213+
214+
return error;
215+
}
216+
217+
template<FilterCoeffType4 xType>
218+
static inline auto get_motionErrorLumaFrac2D( FilterCoeffType4 type )
219+
{
220+
switch( type )
221+
{
222+
case FilterCoeffType4::SkewLeft:
223+
return &motionErrorLumaFrac_loRes2D_sve<xType, FilterCoeffType4::SkewLeft>;
224+
case FilterCoeffType4::SkewRight:
225+
return &motionErrorLumaFrac_loRes2D_sve<xType, FilterCoeffType4::SkewRight>;
226+
case FilterCoeffType4::FullSymmetric:
227+
return &motionErrorLumaFrac_loRes2D_sve<xType, FilterCoeffType4::FullSymmetric>;
228+
case FilterCoeffType4::Generic:
229+
default:
230+
return &motionErrorLumaFrac_loRes2D_sve<xType, FilterCoeffType4::Generic>;
231+
}
232+
}
233+
234+
int motionErrorLumaFrac_loRes_sve( const Pel* org, const ptrdiff_t origStride, const Pel* buf,
235+
const ptrdiff_t buffStride, const int w, const int h, const int16_t* xFilter,
236+
const int16_t* yFilter, const int bitDepth, const int besterror )
237+
{
238+
const FilterCoeffType4 xType = selectFilterType4( xFilter );
239+
const FilterCoeffType4 yType = selectFilterType4( yFilter );
240+
241+
using motionErrorLumaFrac_loResFunc = int ( * )( const Pel*, const ptrdiff_t, const Pel*, const ptrdiff_t, const int,
242+
const int, const int16_t*, const int16_t*, const int, const int );
243+
motionErrorLumaFrac_loResFunc func;
244+
245+
switch( xType )
246+
{
247+
case FilterCoeffType4::SkewLeft:
248+
func = get_motionErrorLumaFrac2D<FilterCoeffType4::SkewLeft>( yType );
249+
break;
250+
case FilterCoeffType4::SkewRight:
251+
func = get_motionErrorLumaFrac2D<FilterCoeffType4::SkewRight>( yType );
252+
break;
253+
case FilterCoeffType4::FullSymmetric:
254+
func = get_motionErrorLumaFrac2D<FilterCoeffType4::FullSymmetric>( yType );
255+
break;
256+
case FilterCoeffType4::Generic:
257+
default:
258+
func = get_motionErrorLumaFrac2D<FilterCoeffType4::Generic>( yType );
259+
break;
260+
}
261+
262+
return func( org, origStride, buf, buffStride, w, h, xFilter, yFilter, bitDepth, besterror );
263+
}
264+
121265
void applyPlanarCorrection_sve( const Pel* refPel, const ptrdiff_t refStride, Pel* dstPel, const ptrdiff_t dstStride,
122266
const int32_t w, const int32_t h, const ClpRng& clpRng, const uint16_t motionError )
123267
{
@@ -406,6 +550,7 @@ template<>
406550
// Installs the SVE kernels defined in this file into the MCTF function
// pointer members.
void MCTF::_initMCTF_ARM<SVE>()
{
  // Motion-error kernels.
  m_motionErrorLumaInt8 = motionErrorLumaInt_sve;
  // NOTE(review): slot 1 presumably is the low-resolution variant, going by
  // the kernel's name - confirm against the member's declaration.
  m_motionErrorLumaFrac8[1] = motionErrorLumaFrac_loRes_sve;

  // Filtering kernels.
  m_applyPlanarCorrection = applyPlanarCorrection_sve;
  m_applyBlock            = applyBlock_sve;
}

0 commit comments

Comments
 (0)