@@ -1081,6 +1081,53 @@ static bool check_fixWeightedSSE( RdCost* ref, RdCost* opt, unsigned num_cases,
10811081 return passed;
10821082}
10831083
1084+ static bool check_SADwMask ( RdCost* ref, RdCost* opt, unsigned num_cases, int width, int height )
1085+ {
1086+ std::ostringstream sstm;
1087+ sstm << " RdCost::m_afpDistortFunc[0][DF_SAD_WITH_MASK] " << " w=" << width << " h=" << height;
1088+ printf ( " Testing %s\n " , sstm.str ().c_str ());
1089+
1090+ DimensionGenerator rng;
1091+ InputGenerator<Pel> g1{ 1 , /* is_signed=*/ false }; // Masks are either 0 or 1.
1092+ InputGenerator<Pel> g10{ 10 , /* is_signed=*/ false };
1093+
1094+ bool passed = true ;
1095+ for ( unsigned i = 0 ; i < num_cases; i++ )
1096+ {
1097+ int org_stride = rng.get ( width, 1024 );
1098+ int cur_stride = rng.get ( width, 1024 );
1099+ int mask_stride = rng.get ( width, 1024 );
1100+ std::vector<Pel> orgBuf ( org_stride * height );
1101+ std::vector<Pel> curBuf ( cur_stride * height );
1102+ std::vector<Pel> maskBuf ( mask_stride * height );
1103+ bool negStepX = rng.get ( 0 , 1 ) != 0 ;
1104+
1105+ DistParam dtParam;
1106+ dtParam.org .buf = orgBuf.data ();
1107+ dtParam.org .stride = org_stride;
1108+ dtParam.cur .buf = curBuf.data ();
1109+ dtParam.cur .stride = cur_stride;
1110+ dtParam.mask = maskBuf.data () + (negStepX ? width : 0 );
1111+ dtParam.maskStride = mask_stride;
1112+ dtParam.maskStride2 = negStepX ? width : -width;
1113+ dtParam.org .width = width;
1114+ dtParam.org .height = height;
1115+ dtParam.bitDepth = 10 ;
1116+ dtParam.subShift = rng.get ( 0 , 1 );
1117+ dtParam.applyWeight = 0 ; // applyWeight appears to be always zero.
1118+ dtParam.stepX = negStepX ? -1 : 1 ;
1119+
1120+ std::generate ( orgBuf.begin (), orgBuf.end (), g10 );
1121+ std::generate ( curBuf.begin (), curBuf.end (), g10 );
1122+ std::generate ( maskBuf.begin (), maskBuf.end (), g1);
1123+
1124+ Distortion sum_ref = ref->m_afpDistortFunc [0 ][DF_SAD_WITH_MASK]( dtParam );
1125+ Distortion sum_opt = opt->m_afpDistortFunc [0 ][DF_SAD_WITH_MASK]( dtParam );
1126+ passed = compare_value ( sstm.str (), sum_ref, sum_opt ) && passed;
1127+ }
1128+ return passed;
1129+ }
1130+
10841131static bool test_RdCost ()
10851132{
10861133 RdCost ref;
@@ -1098,13 +1145,12 @@ static bool test_RdCost()
10981145 for ( int w : widths )
10991146 {
11001147 passed = check_lumaWeightedSSE ( &ref, &opt, num_cases, w, h ) && passed;
1101- }
1102- }
1103- for ( int h : heights )
1104- {
1105- for ( int w : widths )
1106- {
11071148 passed = check_fixWeightedSSE ( &ref, &opt, num_cases, w, h ) && passed;
1149+
1150+ if (w >= 8 && h >= 8 )
1151+ {
1152+ passed = check_SADwMask ( &ref, &opt, num_cases, w, h ) && passed;
1153+ }
11081154 }
11091155 }
11101156 return passed;
0 commit comments