Skip to content

Commit cfbf796

Browse files
committed
Add unit tests for InterPrediction xFpBiDirOptFlow
We have both x86 and Arm implementations of this kernel, so add some simple unit tests to ensure that the behaviour matches between them. This requires us to extract the C fallback implementation for xFpBiDirOptFlow to its own function and wire it up as the default function pointer. As part of this we also need to move xRightShiftMSB to be a non-member function so it is callable from the new fallback implementation. Similar to other unit tests, we add a new `enableOpt` parameter to the `InterPredInterpolation::init` function to allow us to construct both a reference and optimized object to compare against in the test.
1 parent 74a5c97 commit cfbf796

File tree

3 files changed

+180
-67
lines changed

3 files changed

+180
-67
lines changed

source/Lib/CommonLib/InterPrediction.cpp

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,10 @@ void gradFilterCore(const Pel* pSrc, int srcStride, int width, int height, int g
154154
}
155155
}
156156

157-
void calcBDOFSumsCore(const Pel* srcY0Tmp, const Pel* srcY1Tmp, Pel* gradX0, Pel* gradX1, Pel* gradY0, Pel* gradY1, int xu, int yu, const int src0Stride, const int src1Stride, const int widthG, const int bitDepth, int* sumAbsGX, int* sumAbsGY, int* sumDIX, int* sumDIY, int* sumSignGY_GX)
157+
void calcBDOFSumsCore( const Pel* srcY0Tmp, const Pel* srcY1Tmp, const Pel* gradX0, const Pel* gradX1,
158+
const Pel* gradY0, const Pel* gradY1, int xu, int yu, const int src0Stride, const int src1Stride,
159+
const int widthG, const int bitDepth, int* sumAbsGX, int* sumAbsGY, int* sumDIX, int* sumDIY,
160+
int* sumSignGY_GX )
158161
{
159162
int shift4 = 4;
160163
int shift5 = 1;
@@ -596,6 +599,68 @@ void InterPrediction::xSubPuMC(CodingUnit& cu, PelUnitBuf& predBuf, const RefPic
596599
cu.affine = isAffine;
597600
}
598601

602+
static inline int xRightShiftMSB( int numer, int denom )
603+
{
604+
return numer >> floorLog2( denom );
605+
}
606+
607+
void xFpBiDirOptFlowCore( const Pel* srcY0, const Pel* srcY1, const Pel* gradX0, const Pel* gradX1, const Pel* gradY0,
608+
const Pel* gradY1, const int width, const int height, Pel* dstY, const ptrdiff_t dstStride,
609+
const int shiftNum, const int offset, const int limit, const ClpRng& clpRng,
610+
const int bitDepth )
611+
{
612+
int xUnit = width >> 2;
613+
int yUnit = height >> 2;
614+
int heightG = height + 2 * BDOF_EXTEND_SIZE;
615+
int widthG = width + 2 * BDOF_EXTEND_SIZE;
616+
617+
int offsetPos = widthG * BDOF_EXTEND_SIZE + BDOF_EXTEND_SIZE;
618+
int stridePredMC = widthG + 2;
619+
620+
const int src0Stride = stridePredMC;
621+
const int src1Stride = stridePredMC;
622+
623+
const Pel* srcY0Temp = srcY0;
624+
const Pel* srcY1Temp = srcY1;
625+
626+
for( int yu = 0; yu < yUnit; yu++ )
627+
{
628+
for( int xu = 0; xu < xUnit; xu++ )
629+
{
630+
int tmpx = 0, tmpy = 0;
631+
int sumAbsGX = 0, sumAbsGY = 0, sumDIX = 0, sumDIY = 0;
632+
int sumSignGY_GX = 0;
633+
634+
const Pel* pGradX0Tmp = gradX0 + ( xu << 2 ) + ( yu << 2 ) * widthG;
635+
const Pel* pGradX1Tmp = gradX1 + ( xu << 2 ) + ( yu << 2 ) * widthG;
636+
const Pel* pGradY0Tmp = gradY0 + ( xu << 2 ) + ( yu << 2 ) * widthG;
637+
const Pel* pGradY1Tmp = gradY1 + ( xu << 2 ) + ( yu << 2 ) * widthG;
638+
const Pel* SrcY1Tmp = srcY1 + ( xu << 2 ) + ( yu << 2 ) * src1Stride;
639+
const Pel* SrcY0Tmp = srcY0 + ( xu << 2 ) + ( yu << 2 ) * src0Stride;
640+
641+
calcBDOFSumsCore( SrcY0Tmp, SrcY1Tmp, pGradX0Tmp, pGradX1Tmp, pGradY0Tmp, pGradY1Tmp, xu, yu, src0Stride,
642+
src1Stride, widthG, bitDepth, &sumAbsGX, &sumAbsGY, &sumDIX, &sumDIY, &sumSignGY_GX );
643+
tmpx = ( sumAbsGX == 0 ? 0 : xRightShiftMSB( 4 * sumDIX, sumAbsGX ) );
644+
tmpx = Clip3( -limit, limit, tmpx );
645+
646+
const int tmpData = sumSignGY_GX * tmpx >> 1;
647+
tmpy = ( sumAbsGY == 0 ? 0 : xRightShiftMSB( ( 4 * sumDIY - tmpData ), sumAbsGY ) );
648+
tmpy = Clip3( -limit, limit, tmpy );
649+
650+
srcY0Temp = srcY0 + ( stridePredMC + 1 ) + ( ( yu * src0Stride + xu ) << 2 );
651+
srcY1Temp = srcY1 + ( stridePredMC + 1 ) + ( ( yu * src0Stride + xu ) << 2 );
652+
pGradX0Tmp = gradX0 + offsetPos + ( ( yu * widthG + xu ) << 2 );
653+
pGradX1Tmp = gradX1 + offsetPos + ( ( yu * widthG + xu ) << 2 );
654+
pGradY0Tmp = gradY0 + offsetPos + ( ( yu * widthG + xu ) << 2 );
655+
pGradY1Tmp = gradY1 + offsetPos + ( ( yu * widthG + xu ) << 2 );
656+
657+
Pel* dstY0 = dstY + ( ( yu * dstStride + xu ) << 2 );
658+
addBDOFAvgCore( srcY0Temp, src0Stride, srcY1Temp, src1Stride, dstY0, dstStride, pGradX0Tmp, pGradX1Tmp,
659+
pGradY0Tmp, pGradY1Tmp, widthG, ( 1 << 2 ), ( 1 << 2 ), tmpx, tmpy, shiftNum, offset, clpRng );
660+
} // xu
661+
} // yu
662+
}
663+
599664
InterPredInterpolation::InterPredInterpolation()
600665
: m_storedMv(nullptr)
601666
, m_skipPROF(false)
@@ -639,7 +704,7 @@ void InterPredInterpolation::destroy()
639704
}
640705
}
641706

642-
void InterPredInterpolation::init()
707+
void InterPredInterpolation::init( bool enableOpt )
643708
{
644709
for( uint32_t c = 0; c < MAX_NUM_COMP; c++ )
645710
{
@@ -672,18 +737,21 @@ void InterPredInterpolation::init()
672737

673738
m_if.initInterpolationFilter( true );
674739

740+
xFpBiDirOptFlow = xFpBiDirOptFlowCore;
675741
xFpBDOFGradFilter = gradFilterCore;
676742
xFpProfGradFilter = gradFilterCore<false>;
677743
xFpApplyPROF = applyPROFCore;
678744
xFpPadDmvr = padDmvrCore;
679745

746+
if( enableOpt )
747+
{
680748
#if ENABLE_SIMD_OPT_BDOF && defined( TARGET_SIMD_X86 )
681-
initInterPredictionX86();
749+
initInterPredictionX86();
682750
#endif
683-
684751
#if ENABLE_SIMD_OPT_BDOF && defined( TARGET_SIMD_ARM )
685-
initInterPredictionARM();
752+
initInterPredictionARM();
686753
#endif
754+
}
687755

688756
if (m_storedMv == nullptr)
689757
{
@@ -841,11 +909,6 @@ void InterPredInterpolation::xPredInterBlk( const ComponentID compID, const Codi
841909
}
842910
}
843911

844-
int InterPredInterpolation::xRightShiftMSB( int numer, int denom )
845-
{
846-
return ( numer >> floorLog2( denom ) );
847-
}
848-
849912
void InterPredInterpolation::xApplyBDOF( PelBuf& yuvDst, const ClpRng& clpRng )
850913
{
851914
const int bitDepth = clpRng.bd;
@@ -868,9 +931,7 @@ void InterPredInterpolation::xApplyBDOF( PelBuf& yuvDst, const ClpRng& clpRng )
868931
const int src1Stride = stridePredMC;
869932

870933
Pel* dstY = yuvDst.buf;
871-
const int dstStride = yuvDst.stride;
872-
const Pel* srcY0Temp = srcY0;
873-
const Pel* srcY1Temp = srcY1;
934+
const int dstStride = yuvDst.stride;
874935

875936
for (int refList = 0; refList < NUM_REF_PIC_LIST_01; refList++)
876937
{
@@ -896,53 +957,8 @@ void InterPredInterpolation::xApplyBDOF( PelBuf& yuvDst, const ClpRng& clpRng )
896957
const int offset = (1 << (shiftNum - 1)) + 2 * IF_INTERNAL_OFFS;
897958
const int limit = (1 << 4) - 1;
898959

899-
if( xFpBiDirOptFlow )
900-
{
901-
xFpBiDirOptFlow( srcY0, srcY1, gradX0, gradX1, gradY0, gradY1, width, height, dstY, dstStride, shiftNum, offset, limit, clpRng, bitDepth );
902-
return;
903-
}
904-
905-
int xUnit = (width >> 2);
906-
int yUnit = (height >> 2);
907-
908-
Pel* dstY0 = dstY;
909-
gradX0 = m_gradX0; gradX1 = m_gradX1;
910-
gradY0 = m_gradY0; gradY1 = m_gradY1;
911-
912-
for (int yu = 0; yu < yUnit; yu++)
913-
{
914-
for (int xu = 0; xu < xUnit; xu++)
915-
{
916-
int tmpx = 0, tmpy = 0;
917-
int sumAbsGX = 0, sumAbsGY = 0, sumDIX = 0, sumDIY = 0;
918-
int sumSignGY_GX = 0;
919-
920-
Pel* pGradX0Tmp = m_gradX0 + (xu << 2) + (yu << 2) * widthG;
921-
Pel* pGradX1Tmp = m_gradX1 + (xu << 2) + (yu << 2) * widthG;
922-
Pel* pGradY0Tmp = m_gradY0 + (xu << 2) + (yu << 2) * widthG;
923-
Pel* pGradY1Tmp = m_gradY1 + (xu << 2) + (yu << 2) * widthG;
924-
const Pel* SrcY1Tmp = srcY1 + (xu << 2) + (yu << 2) * src1Stride;
925-
const Pel* SrcY0Tmp = srcY0 + (xu << 2) + (yu << 2) * src0Stride;
926-
927-
calcBDOFSumsCore(SrcY0Tmp, SrcY1Tmp, pGradX0Tmp, pGradX1Tmp, pGradY0Tmp, pGradY1Tmp, xu, yu, src0Stride, src1Stride, widthG, bitDepth, &sumAbsGX, &sumAbsGY, &sumDIX, &sumDIY, &sumSignGY_GX);
928-
tmpx = (sumAbsGX == 0 ? 0 : xRightShiftMSB(4 * sumDIX, sumAbsGX));
929-
tmpx = Clip3(-limit, limit, tmpx);
930-
931-
const int tmpData = sumSignGY_GX * tmpx >> 1;
932-
tmpy = (sumAbsGY == 0 ? 0 : xRightShiftMSB((4 * sumDIY - tmpData), sumAbsGY));
933-
tmpy = Clip3(-limit, limit, tmpy);
934-
935-
srcY0Temp = srcY0 + (stridePredMC + 1) + ((yu*src0Stride + xu) << 2);
936-
srcY1Temp = srcY1 + (stridePredMC + 1) + ((yu*src0Stride + xu) << 2);
937-
gradX0 = m_gradX0 + offsetPos + ((yu*widthG + xu) << 2);
938-
gradX1 = m_gradX1 + offsetPos + ((yu*widthG + xu) << 2);
939-
gradY0 = m_gradY0 + offsetPos + ((yu*widthG + xu) << 2);
940-
gradY1 = m_gradY1 + offsetPos + ((yu*widthG + xu) << 2);
941-
942-
dstY0 = dstY + ((yu*dstStride + xu) << 2);
943-
addBDOFAvgCore(srcY0Temp, src0Stride, srcY1Temp, src1Stride, dstY0, dstStride, gradX0, gradX1, gradY0, gradY1, widthG, (1 << 2), (1 << 2), tmpx, tmpy, shiftNum, offset, clpRng);
944-
} // xu
945-
} // yu
960+
xFpBiDirOptFlow( srcY0, srcY1, gradX0, gradX1, gradY0, gradY1, width, height, dstY, dstStride, shiftNum, offset,
961+
limit, clpRng, bitDepth );
946962
}
947963

948964
void InterPredInterpolation::xWeightedAverage( const CodingUnit& cu, const CPelUnitBuf& pcYuvSrc0, const CPelUnitBuf& pcYuvSrc1, PelUnitBuf& pcYuvDst, const bool bdofApplied, PelUnitBuf *yuvPredTmp )

source/Lib/CommonLib/InterPrediction.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,16 @@ class InterPredInterpolation
8989
Pel* m_filteredBlockTmp [LUMA_INTERPOLATION_FILTER_SUB_SAMPLE_POSITIONS_SIGNAL][MAX_NUM_COMP];
9090
int m_ifpLines;
9191

92-
int xRightShiftMSB ( int numer, int denom );
9392
void xApplyBDOF ( PelBuf& yuvDst, const ClpRng& clpRng );
93+
94+
public:
9495
void(*xFpBiDirOptFlow) ( const Pel* srcY0, const Pel* srcY1, const Pel* gradX0, const Pel* gradX1, const Pel* gradY0, const Pel* gradY1, const int width, const int height, Pel* dstY, const ptrdiff_t dstStride, const int shiftNum, const int offset, const int limit, const ClpRng& clpRng, const int bitDepth ) = nullptr;
9596
void(*xFpBDOFGradFilter) ( const Pel* pSrc, int srcStride, int width, int height, int gradStride, Pel* gradX, Pel* gradY, const int bitDepth );
9697
void(*xFpProfGradFilter) ( const Pel* pSrc, int srcStride, int width, int height, int gradStride, Pel* gradX, Pel* gradY, const int bitDepth );
9798
void(*xFpApplyPROF) ( Pel* dst, int dstStride, const Pel* src, int srcStride, int width, int height, const Pel* gradX, const Pel* gradY, int gradStride, const int* dMvX, const int* dMvY, int dMvStride, const bool& bi, int shiftNum, Pel offset, const ClpRng& clpRng );
9899
void(*xFpPadDmvr) ( const Pel* src, const int srcStride, Pel* dst, const int dstStride, int width, int height, int padSize );
99100

101+
protected:
100102
#if ENABLE_SIMD_OPT_BDOF && defined( TARGET_SIMD_X86 )
101103
void initInterPredictionX86();
102104
template <X86_VEXT vext>
@@ -109,7 +111,6 @@ class InterPredInterpolation
109111
void _initInterPredictionARM();
110112
#endif
111113

112-
protected:
113114
void xWeightedAverage ( const CodingUnit& cu, const CPelUnitBuf& pcYuvSrc0, const CPelUnitBuf& pcYuvSrc1, PelUnitBuf& pcYuvDst, const bool bdofApplied, PelUnitBuf *yuvPredTmp = NULL );
114115
void xPredAffineBlk ( const ComponentID compID, const CodingUnit& cu, const Picture* refPic, const Mv* _mv, PelUnitBuf& dstPic, const bool bi, const ClpRng& clpRng, const RefPicList refPicList = REF_PIC_LIST_X);
115116
void xPredInterBlk( const ComponentID compID, const CodingUnit& cu, const Picture* refPic, const Mv& _mv, PelUnitBuf& dstPic, const bool bi, const ClpRng& clpRng
@@ -126,8 +127,8 @@ class InterPredInterpolation
126127
public:
127128
InterPredInterpolation();
128129
virtual ~InterPredInterpolation();
129-
void destroy ();
130-
void init ();
130+
void destroy();
131+
void init( bool enableOpt = true );
131132

132133
void weightedGeoBlk ( const ClpRngs &clpRngs, CodingUnit& cu, const uint8_t splitDir, int32_t channel,
133134
PelUnitBuf &predDst, PelUnitBuf &predSrc0, PelUnitBuf &predSrc1);

test/vvenc_unit_test/vvenc_unit_test.cpp

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ POSSIBILITY OF SUCH DAMAGE.
5252
#include <stdio.h>
5353
#include <time.h>
5454

55+
#include "CommonLib/InterPrediction.h"
5556
#include "CommonLib/MCTF.h"
5657
#include "CommonLib/TrQuant_EMT.h"
5758
#include "CommonLib/TypeDef.h"
@@ -71,8 +72,9 @@ static inline bool compare_value( const std::string& context, const T ref, const
7172
return opt == ref;
7273
}
7374

74-
static inline bool compare_values_2d( const std::string& context, const TCoeff* ref, const TCoeff* opt, unsigned rows,
75-
unsigned cols, unsigned stride = 0 )
75+
template<typename T>
76+
static inline bool compare_values_2d( const std::string& context, const T* ref, const T* opt, unsigned rows,
77+
unsigned cols, unsigned stride = 0 )
7678
{
7779
stride = stride != 0 ? stride : cols;
7880

@@ -360,6 +362,98 @@ static bool test_MCTF()
360362
}
361363
#endif
362364

365+
#if ENABLE_SIMD_OPT_BDOF
366+
template<typename G>
367+
static bool check_one_biDirOptFlow( InterPredInterpolation* ref, InterPredInterpolation* opt, int width, int height,
368+
ptrdiff_t dstStride, int shift, int offset, int limit, ClpRng clpRng, int bitDepth,
369+
G input_generator )
370+
{
371+
CHECK( width % 8, "Width must be a multiple of eight" );
372+
CHECK( height % 8, "Height must be a multiple of eight" );
373+
374+
std::ostringstream sstm;
375+
sstm << "biDirOptFlow width=" << width << " height=" << height << " shift=" << shift << " offset=" << offset
376+
<< " limit=" << limit;
377+
378+
int srcStride = width + 2 * BDOF_EXTEND_SIZE + 2;
379+
int gradStride = width + 2;
380+
381+
std::vector<Pel> srcY0( srcStride * ( height + 2 ) );
382+
std::vector<Pel> srcY1( srcStride * ( height + 2 ) );
383+
std::vector<Pel> gradX0( gradStride * ( height + 2 ) );
384+
std::vector<Pel> gradX1( gradStride * ( height + 2 ) );
385+
std::vector<Pel> gradY0( gradStride * ( height + 2 ) );
386+
std::vector<Pel> gradY1( gradStride * ( height + 2 ) );
387+
std::vector<Pel> dstYref( dstStride * height );
388+
std::vector<Pel> dstYopt( dstStride * height );
389+
390+
// Initialize source buffers.
391+
std::generate( srcY0.begin(), srcY0.end(), input_generator );
392+
std::generate( srcY1.begin(), srcY1.end(), input_generator );
393+
std::generate( gradX0.begin(), gradX0.end(), input_generator );
394+
std::generate( gradX1.begin(), gradX1.end(), input_generator );
395+
std::generate( gradY0.begin(), gradY0.end(), input_generator );
396+
std::generate( gradY1.begin(), gradY1.end(), input_generator );
397+
398+
ref->xFpBiDirOptFlow( srcY0.data(), srcY1.data(), gradX0.data(), gradX1.data(), gradY0.data(), gradY1.data(), width,
399+
height, dstYref.data(), dstStride, shift, offset, limit, clpRng, bitDepth );
400+
opt->xFpBiDirOptFlow( srcY0.data(), srcY1.data(), gradX0.data(), gradX1.data(), gradY0.data(), gradY1.data(), width,
401+
height, dstYopt.data(), dstStride, shift, offset, limit, clpRng, bitDepth );
402+
return compare_values_2d( sstm.str(), dstYref.data(), dstYopt.data(), height, dstStride );
403+
}
404+
405+
static bool check_biDirOptFlow( InterPredInterpolation* ref, InterPredInterpolation* opt, unsigned num_cases, int width,
406+
int height )
407+
{
408+
printf( "Testing InterPred::xFpBiDirOptFlow w=%d h=%d\n", width, height );
409+
InputGenerator<TCoeff> g{ 10 };
410+
DimensionGenerator rng;
411+
412+
for( unsigned i = 0; i < num_cases; ++i )
413+
{
414+
// Width is either 8 or 16.
415+
// DstStride is a multiple of eight in the range width to 128 inclusive.
416+
unsigned dstStride = rng.get( width, 128, 8 );
417+
418+
for( int bitDepth = 8; bitDepth <= 10; bitDepth += 2 )
419+
{
420+
const unsigned shift = IF_INTERNAL_PREC + 1 - bitDepth;
421+
const int offset = ( 1 << ( shift - 1 ) ) + 2 * IF_INTERNAL_OFFS;
422+
const int limit = ( 1 << 4 ) - 1;
423+
ClpRng clpRng{ bitDepth };
424+
425+
if( !check_one_biDirOptFlow( ref, opt, width, height, dstStride, shift, offset, limit, clpRng, bitDepth, g ) )
426+
{
427+
return false;
428+
}
429+
}
430+
}
431+
432+
return true;
433+
}
434+
435+
static bool test_InterPred()
436+
{
437+
InterPredInterpolation ref;
438+
InterPredInterpolation opt;
439+
440+
ref.init( /*enableOpt=*/false );
441+
opt.init( /*enableOpt=*/true );
442+
443+
unsigned num_cases = NUM_CASES;
444+
bool passed = true;
445+
446+
for( unsigned width = 8; width <= 16; width += 8 )
447+
{
448+
for( unsigned height = 8; height <= 16; height += 8 )
449+
{
450+
passed = check_biDirOptFlow( &ref, &opt, num_cases, width, height ) && passed;
451+
}
452+
}
453+
return passed;
454+
}
455+
#endif
456+
363457
int main( int argc, char** argv )
364458
{
365459
unsigned seed = ( unsigned ) time( NULL );
@@ -370,10 +464,12 @@ int main( int argc, char** argv )
370464
#if ENABLE_SIMD_TRAFO
371465
passed = test_TCoeffOps() && passed;
372466
#endif
373-
374467
#if ENABLE_SIMD_OPT_MCTF
375468
passed = test_MCTF() && passed;
376469
#endif
470+
#if ENABLE_SIMD_OPT_BDOF
471+
passed = test_InterPred() && passed;
472+
#endif
377473

378474
if( !passed )
379475
{

0 commit comments

Comments
 (0)