Skip to content

Commit bbcfa94

Browse files
authored
Merge pull request #582 from athulya-arm/athulya-arm/gradFilter
Add Neon implementation for gradFilter
2 parents 1822bfd + 69b4b63 commit bbcfa94

File tree

1 file changed

+111
-1
lines changed

1 file changed

+111
-1
lines changed

source/Lib/CommonLib/arm/InterPredARM.h

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,120 @@ void BiOptFlowCoreARMSIMD( const Pel* srcY0, const Pel* srcY1, const Pel* gradX0
206206
}
207207
}
208208

209+
template<bool PAD = true>
210+
void gradFilter_neon( const Pel* pSrc, int srcStride, int width, int height, int gradStride, Pel* gradX, Pel* gradY,
211+
const int bitDepth )
212+
{
213+
const Pel* srcTmp = pSrc + srcStride + 1;
214+
Pel* gradXTmp = gradX + gradStride + 1;
215+
Pel* gradYTmp = gradY + gradStride + 1;
216+
int widthInside = width - 2 * BDOF_EXTEND_SIZE;
217+
int heightInside = height - 2 * BDOF_EXTEND_SIZE;
218+
static constexpr int shift = 6;
219+
220+
CHECK( widthInside < 4, "(Width - 2) must be greater than or equal to 4!" );
221+
CHECK( heightInside % 2 != 0, "(Height - 2) must be multiple of 2!" );
222+
223+
if( widthInside % 8 == 0 )
224+
{
225+
int y = heightInside;
226+
do
227+
{
228+
int x = widthInside;
229+
do
230+
{
231+
int16x8_t srcRight = vld1q_s16( srcTmp + 1 );
232+
int16x8_t srcLeft = vld1q_s16( srcTmp - 1 );
233+
234+
int16x8_t srcBottom = vld1q_s16( srcTmp + srcStride );
235+
int16x8_t srcTop = vld1q_s16( srcTmp - srcStride );
236+
237+
srcRight = vshrq_n_s16( srcRight, shift );
238+
srcLeft = vshrq_n_s16( srcLeft, shift );
239+
srcBottom = vshrq_n_s16( srcBottom, shift );
240+
srcTop = vshrq_n_s16( srcTop, shift );
241+
242+
const int16x8_t grad_x = vsubq_s16( srcRight, srcLeft );
243+
const int16x8_t grad_y = vsubq_s16( srcBottom, srcTop );
244+
245+
vst1q_s16( gradXTmp, grad_x );
246+
vst1q_s16( gradYTmp, grad_y );
247+
248+
srcTmp += 8;
249+
gradXTmp += 8;
250+
gradYTmp += 8;
251+
x -= 8;
252+
} while( x != 0 );
253+
254+
gradXTmp += gradStride - widthInside;
255+
gradYTmp += gradStride - widthInside;
256+
srcTmp += srcStride - widthInside;
257+
} while( --y != 0 );
258+
}
259+
else
260+
{
261+
CHECK( widthInside != 4, "(Width - 2) must be equal to 4!" );
262+
int y = heightInside >> 1;
263+
264+
int16x8_t srcTop = vcombine_s16( vld1_s16( srcTmp - srcStride ), vld1_s16( srcTmp ) );
265+
srcTop = vshrq_n_s16( srcTop, shift );
266+
267+
do
268+
{
269+
int16x8_t srcRight = vcombine_s16( vld1_s16( srcTmp + 1 ), vld1_s16( srcTmp + srcStride + 1 ) );
270+
int16x8_t srcLeft = vcombine_s16( vld1_s16( srcTmp - 1 ), vld1_s16( srcTmp + srcStride - 1 ) );
271+
int16x8_t srcBottom = vcombine_s16( vld1_s16( srcTmp + srcStride ), vld1_s16( srcTmp + ( srcStride << 1 ) ) );
272+
273+
srcRight = vshrq_n_s16( srcRight, shift );
274+
srcLeft = vshrq_n_s16( srcLeft, shift );
275+
srcBottom = vshrq_n_s16( srcBottom, shift );
276+
277+
const int16x8_t grad_x = vsubq_s16( srcRight, srcLeft );
278+
const int16x8_t grad_y = vsubq_s16( srcBottom, srcTop );
279+
280+
vst1_s16( gradXTmp, vget_low_s16( grad_x ) );
281+
vst1_s16( gradXTmp + gradStride, vget_high_s16( grad_x ) );
282+
vst1_s16( gradYTmp, vget_low_s16( grad_y ) );
283+
vst1_s16( gradYTmp + gradStride, vget_high_s16( grad_y ) );
284+
285+
gradXTmp += gradStride << 1;
286+
gradYTmp += gradStride << 1;
287+
srcTmp += srcStride << 1;
288+
srcTop = srcBottom; // For next iteration.
289+
} while( --y != 0 );
290+
}
291+
292+
if( PAD )
293+
{
294+
gradXTmp = gradX + gradStride + 1;
295+
gradYTmp = gradY + gradStride + 1;
296+
int y = heightInside;
297+
do
298+
{
299+
gradXTmp[-1] = gradXTmp[0];
300+
gradXTmp[widthInside] = gradXTmp[widthInside - 1];
301+
gradXTmp += gradStride;
302+
303+
gradYTmp[-1] = gradYTmp[0];
304+
gradYTmp[widthInside] = gradYTmp[widthInside - 1];
305+
gradYTmp += gradStride;
306+
} while( --y != 0 );
307+
308+
gradXTmp = gradX + gradStride;
309+
gradYTmp = gradY + gradStride;
310+
memcpy( gradXTmp - gradStride, gradXTmp, sizeof( Pel ) * width );
311+
memcpy( gradXTmp + heightInside * gradStride, gradXTmp + ( heightInside - 1 ) * gradStride, sizeof( Pel ) * width );
312+
memcpy( gradYTmp - gradStride, gradYTmp, sizeof( Pel ) * width );
313+
memcpy( gradYTmp + heightInside * gradStride, gradYTmp + ( heightInside - 1 ) * gradStride, sizeof( Pel ) * width );
314+
}
315+
}
316+
209317
template<ARM_VEXT vext>
210318
void InterPredInterpolation::_initInterPredictionARM()
211319
{
212-
xFpBiDirOptFlow = BiOptFlowCoreARMSIMD<vext>;
320+
xFpBiDirOptFlow = BiOptFlowCoreARMSIMD<vext>;
321+
xFpBDOFGradFilter = gradFilter_neon;
322+
xFpProfGradFilter = gradFilter_neon<false>;
213323
}
214324

215325
// Explicit instantiation of _initInterPredictionARM for the SIMDARM template argument.
template void InterPredInterpolation::_initInterPredictionARM<SIMDARM>();

0 commit comments

Comments
 (0)