@@ -71,6 +71,8 @@ namespace po = apputils::program_options;
7171
7272#define NUM_CASES 100
7373
74+ static bool g_fastUnitTest = false ;
75+
7476template <typename T>
7577static inline bool compare_value ( const std::string& context, const T ref, const T opt )
7678{
@@ -756,8 +758,8 @@ static bool check_applyFrac( MCTF* ref, MCTF* opt, int w, int h )
756758 for ( int yIndex = 0 ; yIndex < motionVectorFactor; ++yIndex )
757759 {
758760 // Stride is often the width of a video frame, so use the width of 8K as an upper bound.
759- unsigned orgStride = rng.get ( w, 8192 );
760- unsigned dstStride = rng.get ( w, 8192 );
761+ unsigned orgStride = rng.get ( w, g_fastUnitTest ? 512 : 8192 );
762+ unsigned dstStride = rng.get ( w, g_fastUnitTest ? 512 : 8192 );
761763 if ( !check_one_applyFrac<Ch, NumTaps>( ref, opt, orgStride, dstStride, w, h, xIndex, yIndex, bitDepth, g ) )
762764 {
763765 return false ;
@@ -880,8 +882,8 @@ static bool check_motionErrorLumaFrac8( MCTF* ref, MCTF* opt, int w, int h )
880882 for ( unsigned yIndex = 1 ; yIndex < motionVectorFactor; ++yIndex )
881883 {
882884 // Stride is often the width of a video frame, so use the width of 8K as an upper bound.
883- unsigned orgStride = rng.get ( w, 8192 );
884- unsigned bufStride = rng.get ( w, 8192 );
885+ unsigned orgStride = rng.get ( w, g_fastUnitTest ? 512 : 8192 );
886+ unsigned bufStride = rng.get ( w, g_fastUnitTest ? 512 : 8192 );
885887 unsigned besterror = INT_MAX;
886888 if ( !check_one_motionErrorLumaFrac8<lowRes>( ref, opt, orgStride, bufStride, w, h, xIndex, yIndex, bitDepth,
887889 besterror, g ) )
@@ -1179,9 +1181,9 @@ static bool check_lumaWeightedSSE( RdCost* ref, RdCost* opt, unsigned num_cases,
11791181 bool passed = true ;
11801182 for ( unsigned i = 0 ; i < num_cases; i++ )
11811183 {
1182- int org_stride = rng.get ( width, 1024 );
1183- int cur_stride = rng.get ( width, 1024 );
1184- int luma_stride = rng.get ( width, 1024 );
1184+ int org_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
1185+ int cur_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
1186+ int luma_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
11851187 std::vector<Pel> orgBuf ( org_stride * height );
11861188 std::vector<Pel> curBuf ( cur_stride * height );
11871189 std::vector<Pel> orgLumaBuf ( luma_stride * height * 2 );
@@ -1231,8 +1233,8 @@ static bool check_fixWeightedSSE( RdCost* ref, RdCost* opt, unsigned num_cases,
12311233 bool passed = true ;
12321234 for ( unsigned i = 0 ; i < num_cases; i++ )
12331235 {
1234- int org_stride = rng.get ( width, 1024 );
1235- int cur_stride = rng.get ( width, 1024 );
1236+ int org_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
1237+ int cur_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
12361238 std::vector<Pel> orgBuf ( org_stride * height );
12371239 std::vector<Pel> curBuf ( cur_stride * height );
12381240
@@ -1269,9 +1271,9 @@ static bool check_SADwMask( RdCost* ref, RdCost* opt, unsigned num_cases, int wi
12691271 bool passed = true ;
12701272 for ( unsigned i = 0 ; i < num_cases; i++ )
12711273 {
1272- int org_stride = rng.get ( width, 1024 );
1273- int cur_stride = rng.get ( width, 1024 );
1274- int mask_stride = rng.get ( width, 1024 );
1274+ int org_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
1275+ int cur_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
1276+ int mask_stride = rng.get ( width, g_fastUnitTest ? 256 : 1024 );
12751277 std::vector<Pel> orgBuf ( org_stride * height );
12761278 std::vector<Pel> curBuf ( cur_stride * height );
12771279 std::vector<Pel> maskBuf ( mask_stride * height );
@@ -1955,9 +1957,9 @@ static bool check_xWeightedGeoBlk( InterpolationFilter* ref, InterpolationFilter
19551957 {
19561958 for ( int bitDepth : { 8 , 10 } )
19571959 {
1958- unsigned src0Stride = rng.get ( width, 128 );
1959- unsigned src1Stride = rng.get ( width, 128 );
1960- unsigned dstStride = rng.get ( width, 128 );
1960+ unsigned src0Stride = rng.get ( width, MAX_CU_SIZE );
1961+ unsigned src1Stride = rng.get ( width, MAX_CU_SIZE );
1962+ unsigned dstStride = rng.get ( width, MAX_CU_SIZE );
19611963 if ( !check_one_xWeightedGeoBlk ( ref, opt, src0Stride, src1Stride, dstStride, width, height, bitDepth, g ) )
19621964 {
19631965 return false ;
@@ -2102,6 +2104,7 @@ static const UnitTestEntry test_suites[] = {
21022104
21032105struct UnitTestArgs
21042106{
2107+ bool isFast = false ;
21052108 bool show_help = false ;
21062109 int seed;
21072110 std::string testcase;
@@ -2133,11 +2136,14 @@ UnitTestArgs parse_args( int argc, char* argv[] )
21332136 opts.addOptions ()
21342137 ( " help,h" , args.show_help , " Show help" , true )
21352138 ( " seed" , args.seed , " Set random seed for running tests" )
2136- ( " testcase,t" , args.testcase , get_testcase_help_text (), false );
2139+ ( " testcase,t" , args.testcase , get_testcase_help_text (), false )
2140+ ( " fast" , args.isFast , " Run a fast but less real-world accurate version of the tests" , false );
21372141
21382142 po::SilentReporter err;
21392143 po::scanArgv ( opts, argc, ( const char ** )argv, err );
21402144
2145+ g_fastUnitTest = args.isFast ;
2146+
21412147 if ( args.show_help )
21422148 {
21432149 std::ostringstream help_sstm;
0 commit comments