diff --git a/test/vvenc_unit_test/vvenc_unit_test.cpp b/test/vvenc_unit_test/vvenc_unit_test.cpp index d05a86aa..d88157a9 100644 --- a/test/vvenc_unit_test/vvenc_unit_test.cpp +++ b/test/vvenc_unit_test/vvenc_unit_test.cpp @@ -1299,6 +1299,46 @@ static bool check_SAD( RdCost* ref, RdCost* opt, unsigned num_cases, int width, return passed; } +static bool check_HADs( RdCost* ref, RdCost* opt, unsigned num_cases, int width, int height, bool fast ) +{ + std::ostringstream sstm; + const char* fast_str = fast ? "_fast" : ""; + sstm << "RdCost::m_afpDistortFunc[0][DF_HAD" << width << fast_str << "] " + << " w=" << width << " h=" << height; + printf( "Testing %s\n", sstm.str().c_str() ); + + DimensionGenerator rng; + InputGenerator g10{ 10, /*is_signed=*/false }; + + bool passed = true; + for( unsigned i = 0; i < num_cases; i++ ) + { + int org_stride = rng.get( width, g_fastUnitTest ? 256 : 1024 ); + int cur_stride = rng.get( width, g_fastUnitTest ? 256 : 1024 ); + std::vector orgBuf( org_stride * height ); + std::vector curBuf( cur_stride * height ); + + DistParam dtParam; + dtParam.org.buf = orgBuf.data(); + dtParam.org.stride = org_stride; + dtParam.cur.buf = curBuf.data(); + dtParam.cur.stride = cur_stride; + dtParam.org.width = width; + dtParam.org.height = height; + dtParam.bitDepth = 10; + dtParam.applyWeight = 0; // applyWeight appears to be always zero. + + std::generate( orgBuf.begin(), orgBuf.end(), g10 ); + std::generate( curBuf.begin(), curBuf.end(), g10 ); + + const int index = ( fast ? DF_HAD_fast : DF_HAD ) + log2( width ); + Distortion sum_ref = ref->m_afpDistortFunc[0][( DFunc )( index )]( dtParam ); + Distortion sum_opt = opt->m_afpDistortFunc[0][( DFunc )( index )]( dtParam ); + passed = compare_value( sstm.str(), sum_ref, sum_opt ) && passed; + } + return passed; +} + static bool check_SADwMask( RdCost* ref, RdCost* opt, unsigned num_cases, int width, int height ) { std::ostringstream sstm; @@ -1366,6 +1406,12 @@ static bool test_RdCost() passed = check_fixWeightedSSE( &ref, &opt, num_cases, w, h ) && passed; passed = check_SAD( &ref, &opt, num_cases, w, h ) && passed; + if( w >= 2 ) + { + passed = check_HADs( &ref, &opt, num_cases, w, h, /*fast=*/true ) && passed; + passed = check_HADs( &ref, &opt, num_cases, w, h, /*fast=*/false ) && passed; + } + if (w >= 8 && h >= 8) { passed = check_SADwMask( &ref, &opt, num_cases, w, h ) && passed;