Skip to content

Commit 90ff897

Browse files
Configure unit tests to run more iterations.
1 parent 6623d16 commit 90ff897

File tree

1 file changed

+52
-17
lines changed

1 file changed

+52
-17
lines changed

onnxruntime/test/mlas/unittest/test_conv2d_fixture.h

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,58 @@ class Conv2dShortExecuteTest : public MlasTestFixture<Conv2dTester> {
4646
}
4747

4848
void TestBody() override {
49-
MlasTestFixture<Conv2dTester>::mlas_tester->Test(
50-
BatchCount_,
51-
GroupCount_,
52-
InputChannels_,
53-
InputHeight_,
54-
InputWidth_,
55-
FilterCount_,
56-
KernelHeight_,
57-
KernelWidth_,
58-
PaddingLeftHeight_,
59-
PaddingLeftWidth_,
60-
PaddingRightHeight_,
61-
PaddingRightWidth_,
62-
DilationHeight_,
63-
DilationWidth_,
64-
StrideHeight_,
65-
StrideWidth_);
49+
const char* iter_env = std::getenv("MLAS_BENCHMARK_ITERATIONS");
50+
int benchmark_iterations = iter_env ? std::atoi(iter_env) : 100;
51+
int warmup_iterations = benchmark_iterations > 1 ? 10 : 0;
52+
53+
if (benchmark_iterations > 1) {
54+
for (int i = 0; i < warmup_iterations; i++) {
55+
MlasTestFixture<Conv2dTester>::mlas_tester->Test(
56+
BatchCount_, GroupCount_, InputChannels_,
57+
InputHeight_, InputWidth_, FilterCount_,
58+
KernelHeight_, KernelWidth_,
59+
PaddingLeftHeight_, PaddingLeftWidth_,
60+
PaddingRightHeight_, PaddingRightWidth_,
61+
DilationHeight_, DilationWidth_,
62+
StrideHeight_, StrideWidth_);
63+
}
64+
65+
auto start = std::chrono::high_resolution_clock::now();
66+
67+
for (int i = 0; i < benchmark_iterations; i++) {
68+
MlasTestFixture<Conv2dTester>::mlas_tester->Test(
69+
BatchCount_, GroupCount_, InputChannels_,
70+
InputHeight_, InputWidth_, FilterCount_,
71+
KernelHeight_, KernelWidth_,
72+
PaddingLeftHeight_, PaddingLeftWidth_,
73+
PaddingRightHeight_, PaddingRightWidth_,
74+
DilationHeight_, DilationWidth_,
75+
StrideHeight_, StrideWidth_);
76+
}
77+
78+
auto end = std::chrono::high_resolution_clock::now();
79+
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
80+
81+
double avg_time_us = duration.count() / double(benchmark_iterations);
82+
double ops_per_sec = 1000000.0 / avg_time_us;
83+
84+
printf(
85+
"BENCHMARK: B%zu/G%zu/IC%zu/FC%zu/IH%zu/IW%zu/KH%zu/KW%zu/Pad%zu,%zu,%zu,%zu/Dilation%zu,%zu/Stride%zu,%zu - "
86+
"Avg: %.2f μs, Rate: %.2f ops/sec (%d iterations)\n",
87+
BatchCount_, GroupCount_, InputChannels_, FilterCount_, InputHeight_, InputWidth_,
88+
KernelHeight_, KernelWidth_, PaddingLeftHeight_, PaddingLeftWidth_, PaddingRightHeight_,
89+
PaddingRightWidth_, DilationHeight_, DilationWidth_, StrideHeight_, StrideWidth_,
90+
avg_time_us, ops_per_sec, benchmark_iterations);
91+
} else {
92+
MlasTestFixture<Conv2dTester>::mlas_tester->Test(
93+
BatchCount_, GroupCount_, InputChannels_,
94+
InputHeight_, InputWidth_, FilterCount_,
95+
KernelHeight_, KernelWidth_,
96+
PaddingLeftHeight_, PaddingLeftWidth_,
97+
PaddingRightHeight_, PaddingRightWidth_,
98+
DilationHeight_, DilationWidth_,
99+
StrideHeight_, StrideWidth_);
100+
}
66101
}
67102

68103
static size_t RegisterSingleTest(

0 commit comments

Comments
 (0)