Skip to content

Commit f895d35

Browse files
HanKuanChenhanhanW
andauthored
Add expected_bf16_threshold support to iree-run-module (#22547)
Previously, iree-run-module did not handle expected_bf16_threshold, so BF16 outputs could not be validated with a BF16-specific comparison threshold. Extend the tool to parse and apply expected_bf16_threshold when checking results, enabling correct tolerance handling for BF16 outputs. --------- Signed-off-by: Han-Kuan Chen <hankuan.chen@sifive.com> Signed-off-by: hanhanW <hanhan0912@gmail.com> Co-authored-by: hanhanW <hanhan0912@gmail.com>
1 parent 8223e1f commit f895d35

9 files changed

Lines changed: 230 additions & 13 deletions

runtime/src/iree/tooling/buffer_view_matchers.c

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,23 @@ static bool iree_hal_compare_strided_elements_approximate_absolute_f64(
125125
return true;
126126
}
127127

128+
static bool iree_hal_compare_strided_elements_approximate_absolute_bf16(
129+
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
130+
const uint16_t* expected_ptr, iree_host_size_t expected_stride,
131+
const uint16_t* actual_ptr, iree_host_size_t actual_stride,
132+
iree_host_size_t* out_index) {
133+
for (iree_host_size_t i = 0; i < element_count; ++i) {
134+
if (fabsf(iree_math_bf16_to_f32(*expected_ptr) -
135+
iree_math_bf16_to_f32(*actual_ptr)) > equality.bf16_threshold) {
136+
*out_index = i;
137+
return false;
138+
}
139+
expected_ptr += expected_stride;
140+
actual_ptr += actual_stride;
141+
}
142+
return true;
143+
}
144+
128145
static bool iree_hal_compare_strided_elements_approximate_relative_f16(
129146
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
130147
const uint16_t* expected_ptr, iree_host_size_t expected_stride,
@@ -177,6 +194,24 @@ static bool iree_hal_compare_strided_elements_approximate_relative_f64(
177194
return true;
178195
}
179196

197+
static bool iree_hal_compare_strided_elements_approximate_relative_bf16(
198+
iree_hal_buffer_equality_t equality, iree_host_size_t element_count,
199+
const uint16_t* expected_ptr, iree_host_size_t expected_stride,
200+
const uint16_t* actual_ptr, iree_host_size_t actual_stride,
201+
iree_host_size_t* out_index) {
202+
for (iree_host_size_t i = 0; i < element_count; ++i) {
203+
if (fabsf((iree_math_bf16_to_f32(*expected_ptr) -
204+
iree_math_bf16_to_f32(*actual_ptr)) /
205+
iree_math_bf16_to_f32(*expected_ptr)) > equality.bf16_threshold) {
206+
*out_index = i;
207+
return false;
208+
}
209+
expected_ptr += expected_stride;
210+
actual_ptr += actual_stride;
211+
}
212+
return true;
213+
}
214+
180215
static bool iree_hal_compare_strided_elements_approximate_absolute(
181216
iree_hal_buffer_equality_t equality, iree_hal_element_type_t element_type,
182217
iree_host_size_t element_count, iree_const_byte_span_t expected_elements,
@@ -198,6 +233,11 @@ static bool iree_hal_compare_strided_elements_approximate_absolute(
198233
equality, element_count, (const double*)expected_elements.data,
199234
expected_stride, (const double*)actual_elements.data, actual_stride,
200235
out_index);
236+
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
237+
return iree_hal_compare_strided_elements_approximate_absolute_bf16(
238+
equality, element_count, (const uint16_t*)expected_elements.data,
239+
expected_stride, (const uint16_t*)actual_elements.data, actual_stride,
240+
out_index);
201241
default:
202242
return iree_hal_compare_strided_elements_exact(
203243
element_type, element_count, expected_elements, expected_stride,
@@ -226,6 +266,11 @@ static bool iree_hal_compare_strided_elements_approximate_relative(
226266
equality, element_count, (const double*)expected_elements.data,
227267
expected_stride, (const double*)actual_elements.data, actual_stride,
228268
out_index);
269+
case IREE_HAL_ELEMENT_TYPE_BFLOAT_16:
270+
return iree_hal_compare_strided_elements_approximate_relative_bf16(
271+
equality, element_count, (const uint16_t*)expected_elements.data,
272+
expected_stride, (const uint16_t*)actual_elements.data, actual_stride,
273+
out_index);
229274
default:
230275
return iree_hal_compare_strided_elements_exact(
231276
element_type, element_count, expected_elements, expected_stride,

runtime/src/iree/tooling/buffer_view_matchers.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ typedef struct {
5353
float f16_threshold;
5454
float f32_threshold;
5555
double f64_threshold;
56+
float bf16_threshold;
5657
} iree_hal_buffer_equality_t;
5758

5859
// Variant type storing known HAL buffer elements.
@@ -125,6 +126,14 @@ static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_f64(
125126
return element;
126127
}
127128

129+
static inline iree_hal_buffer_element_t iree_hal_make_buffer_element_bf16(
130+
float value) {
131+
iree_hal_buffer_element_t element;
132+
element.type = IREE_HAL_ELEMENT_TYPE_BFLOAT_16;
133+
element.i16 = iree_math_f32_to_bf16(value);
134+
return element;
135+
}
136+
128137
// Returns true if all elements match the uniform value based on |equality|.
129138
// |out_index| will contain the first index that does not match.
130139
bool iree_hal_compare_buffer_elements_broadcast(

runtime/src/iree/tooling/buffer_view_matchers_test.cc

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ static const iree_hal_buffer_equality_t kApproximateAbsoluteEquality = ([]() {
166166
equality.f16_threshold = 0.001f;
167167
equality.f32_threshold = 0.0001f;
168168
equality.f64_threshold = 0.0001;
169+
equality.bf16_threshold = 0.01f;
169170
return equality;
170171
})();
171172

@@ -175,6 +176,7 @@ static const iree_hal_buffer_equality_t kApproximateRelativeEquality = ([]() {
175176
equality.f16_threshold = 0.001f;
176177
equality.f32_threshold = 0.0001f;
177178
equality.f64_threshold = 0.0001;
179+
equality.bf16_threshold = 0.01f;
178180
return equality;
179181
})();
180182

@@ -401,6 +403,64 @@ TEST_F(BufferViewMatchersTest, CompareBroadcastF64NERelative) {
401403
EXPECT_EQ(index, 1);
402404
}
403405

406+
TEST_F(BufferViewMatchersTest, CompareBroadcastBF16EQ) {
407+
const float lhs = 1.0f;
408+
const uint16_t rhs[] = {
409+
iree_math_f32_to_bf16(1.0f),
410+
iree_math_f32_to_bf16(1.0f),
411+
iree_math_f32_to_bf16(1.0f),
412+
};
413+
iree_host_size_t index = 0;
414+
EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast(
415+
kApproximateAbsoluteEquality, iree_hal_make_buffer_element_bf16(lhs),
416+
IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)),
417+
&index));
418+
}
419+
420+
TEST_F(BufferViewMatchersTest, CompareBroadcastBF16EQRelative) {
421+
const float lhs = 1.0f;
422+
const uint16_t rhs[] = {
423+
iree_math_f32_to_bf16(1.0f),
424+
iree_math_f32_to_bf16(1.0f),
425+
iree_math_f32_to_bf16(1.0f),
426+
};
427+
iree_host_size_t index = 0;
428+
EXPECT_TRUE(iree_hal_compare_buffer_elements_broadcast(
429+
kApproximateRelativeEquality, iree_hal_make_buffer_element_bf16(lhs),
430+
IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)),
431+
&index));
432+
}
433+
434+
TEST_F(BufferViewMatchersTest, CompareBroadcastBF16NE) {
435+
const float lhs = 1.0f;
436+
const uint16_t rhs[] = {
437+
iree_math_f32_to_bf16(1.0f),
438+
iree_math_f32_to_bf16(3.0f),
439+
iree_math_f32_to_bf16(4.0f),
440+
};
441+
iree_host_size_t index = 0;
442+
EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast(
443+
kApproximateAbsoluteEquality, iree_hal_make_buffer_element_bf16(lhs),
444+
IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)),
445+
&index));
446+
EXPECT_EQ(index, 1);
447+
}
448+
449+
TEST_F(BufferViewMatchersTest, CompareBroadcastBF16NERelative) {
450+
const float lhs = 1.0f;
451+
const uint16_t rhs[] = {
452+
iree_math_f32_to_bf16(1.0f),
453+
iree_math_f32_to_bf16(3.0f),
454+
iree_math_f32_to_bf16(4.0f),
455+
};
456+
iree_host_size_t index = 0;
457+
EXPECT_FALSE(iree_hal_compare_buffer_elements_broadcast(
458+
kApproximateRelativeEquality, iree_hal_make_buffer_element_bf16(lhs),
459+
IREE_ARRAYSIZE(rhs), iree_make_const_byte_span(rhs, sizeof(rhs)),
460+
&index));
461+
EXPECT_EQ(index, 1);
462+
}
463+
404464
TEST_F(BufferViewMatchersTest, CompareElementwiseF16EQ) {
405465
const uint16_t lhs[] = {
406466
iree_math_f32_to_f16(1.0f),
@@ -520,6 +580,83 @@ TEST_F(BufferViewMatchersTest, CompareElementwiseF64NE) {
520580
EXPECT_EQ(index, 1);
521581
}
522582

583+
TEST_F(BufferViewMatchersTest, CompareElementwiseBF16EQ) {
584+
const uint16_t lhs[] = {
585+
iree_math_f32_to_bf16(1.0f),
586+
iree_math_f32_to_bf16(2.0f),
587+
iree_math_f32_to_bf16(3.0f),
588+
};
589+
const uint16_t rhs[] = {
590+
iree_math_f32_to_bf16(1.0f),
591+
iree_math_f32_to_bf16(2.0f),
592+
iree_math_f32_to_bf16(3.0f),
593+
};
594+
iree_host_size_t index = 0;
595+
EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise(
596+
kApproximateAbsoluteEquality, IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
597+
IREE_ARRAYSIZE(lhs), iree_make_const_byte_span(lhs, sizeof(lhs)),
598+
iree_make_const_byte_span(rhs, sizeof(rhs)), &index));
599+
}
600+
601+
TEST_F(BufferViewMatchersTest, CompareElementwiseBF16NearEQ) {
602+
const uint16_t lhs[] = {
603+
iree_math_f32_to_bf16(1.0f),
604+
iree_math_f32_to_bf16(1.99999f),
605+
iree_math_f32_to_bf16(0.00001f),
606+
iree_math_f32_to_bf16(4.0f),
607+
};
608+
const uint16_t rhs[] = {
609+
iree_math_f32_to_bf16(1.00001f),
610+
iree_math_f32_to_bf16(2.0f),
611+
iree_math_f32_to_bf16(0.0f),
612+
iree_math_f32_to_bf16(4.0f),
613+
};
614+
iree_host_size_t index = 0;
615+
EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise(
616+
kApproximateAbsoluteEquality, IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
617+
IREE_ARRAYSIZE(lhs), iree_make_const_byte_span(lhs, sizeof(lhs)),
618+
iree_make_const_byte_span(rhs, sizeof(rhs)), &index));
619+
}
620+
621+
TEST_F(BufferViewMatchersTest, CompareElementwiseBF16NearEQRelative) {
622+
const uint16_t lhs[] = {
623+
iree_math_f32_to_bf16(100.0f),
624+
iree_math_f32_to_bf16(19.99f),
625+
iree_math_f32_to_bf16(1.00001f),
626+
iree_math_f32_to_bf16(4.0f),
627+
};
628+
const uint16_t rhs[] = {
629+
iree_math_f32_to_bf16(100.01f),
630+
iree_math_f32_to_bf16(20.00f),
631+
iree_math_f32_to_bf16(1.0f),
632+
iree_math_f32_to_bf16(4.0f),
633+
};
634+
iree_host_size_t index = 0;
635+
EXPECT_TRUE(iree_hal_compare_buffer_elements_elementwise(
636+
kApproximateRelativeEquality, IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
637+
IREE_ARRAYSIZE(lhs), iree_make_const_byte_span(lhs, sizeof(lhs)),
638+
iree_make_const_byte_span(rhs, sizeof(rhs)), &index));
639+
}
640+
641+
TEST_F(BufferViewMatchersTest, CompareElementwiseBF16NE) {
642+
const uint16_t lhs[] = {
643+
iree_math_f32_to_bf16(1.0f),
644+
iree_math_f32_to_bf16(2.0f),
645+
iree_math_f32_to_bf16(4.0f),
646+
};
647+
const uint16_t rhs[] = {
648+
iree_math_f32_to_bf16(1.0f),
649+
iree_math_f32_to_bf16(3.0f),
650+
iree_math_f32_to_bf16(4.0f),
651+
};
652+
iree_host_size_t index = 0;
653+
EXPECT_FALSE(iree_hal_compare_buffer_elements_elementwise(
654+
kApproximateAbsoluteEquality, IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
655+
IREE_ARRAYSIZE(lhs), iree_make_const_byte_span(lhs, sizeof(lhs)),
656+
iree_make_const_byte_span(rhs, sizeof(rhs)), &index));
657+
EXPECT_EQ(index, 1);
658+
}
659+
523660
//===----------------------------------------------------------------------===//
524661
// iree_hal_buffer_view_metadata_matcher_t
525662
//===----------------------------------------------------------------------===//
@@ -804,5 +941,41 @@ TEST_F(BufferViewMatchersTest, MismatchContentsF16) {
804941
EXPECT_THAT(sb.ToString(), HasSubstr("element at index 0"));
805942
}
806943

944+
TEST_F(BufferViewMatchersTest, MatchContentsBF16) {
945+
const uint16_t lhs_contents[] = {iree_math_f32_to_bf16(2.0f)};
946+
const uint16_t rhs_contents[] = {iree_math_f32_to_bf16(2.0f)};
947+
iree_hal_dim_t shape[] = {1};
948+
IREE_ASSERT_OK_AND_ASSIGN(
949+
auto lhs,
950+
CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, lhs_contents));
951+
IREE_ASSERT_OK_AND_ASSIGN(
952+
auto rhs,
953+
CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, rhs_contents));
954+
auto sb = StringBuilder::MakeSystem();
955+
bool match = false;
956+
IREE_ASSERT_OK(
957+
iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match));
958+
EXPECT_TRUE(match);
959+
EXPECT_TRUE(sb.ToString().empty());
960+
}
961+
962+
TEST_F(BufferViewMatchersTest, MismatchContentsBF16) {
963+
const uint16_t lhs_contents[] = {iree_math_f32_to_bf16(1.0f)};
964+
const uint16_t rhs_contents[] = {iree_math_f32_to_bf16(2.0f)};
965+
const iree_hal_dim_t shape[] = {1};
966+
IREE_ASSERT_OK_AND_ASSIGN(
967+
auto lhs,
968+
CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, lhs_contents));
969+
IREE_ASSERT_OK_AND_ASSIGN(
970+
auto rhs,
971+
CreateBufferView(shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, rhs_contents));
972+
auto sb = StringBuilder::MakeSystem();
973+
bool match = false;
974+
IREE_ASSERT_OK(
975+
iree_hal_buffer_view_match_equal(kExactEquality, lhs, rhs, sb, &match));
976+
EXPECT_FALSE(match);
977+
EXPECT_THAT(sb.ToString(), HasSubstr("element at index 0"));
978+
}
979+
807980
} // namespace
808981
} // namespace iree

runtime/src/iree/tooling/comparison.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ IREE_FLAG(float, expected_f32_threshold, 0.0001f,
2424
"Threshold under which two f32 values are considered equal.");
2525
IREE_FLAG(double, expected_f64_threshold, 0.0001,
2626
"Threshold under which two f64 values are considered equal.");
27+
IREE_FLAG(float, expected_bf16_threshold, 0.01f,
28+
"Threshold under which two bf16 values are considered equal.");
2729
IREE_FLAG(string, equality_mode, "absolute",
2830
"Choose the type of comparison desired between buffers from "
2931
"[`absolute`, `relative`, `exact`].");
@@ -42,6 +44,7 @@ static iree_hal_buffer_equality_t iree_tooling_equality_from_flags(void) {
4244
equality.f16_threshold = FLAG_expected_f16_threshold;
4345
equality.f32_threshold = FLAG_expected_f32_threshold;
4446
equality.f64_threshold = FLAG_expected_f64_threshold;
47+
equality.bf16_threshold = FLAG_expected_bf16_threshold;
4548
return equality;
4649
}
4750

tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O0.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,6 @@
310310
"onnx/node/generated/test_bernoulli_double",
311311
"onnx/node/generated/test_bernoulli_double_expanded",
312312
"onnx/node/generated/test_bernoulli_expanded",
313-
"onnx/node/generated/test_cast_FLOAT_to_BFLOAT16",
314-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16",
315-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16_expanded",
316313
"onnx/node/generated/test_convtranspose_output_shape",
317314
"onnx/node/generated/test_gridsample_nearest",
318315
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1",

tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync_O2.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,6 @@
312312
"onnx/node/generated/test_bernoulli_double",
313313
"onnx/node/generated/test_bernoulli_double_expanded",
314314
"onnx/node/generated/test_bernoulli_expanded",
315-
"onnx/node/generated/test_cast_FLOAT_to_BFLOAT16",
316-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16",
317-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16_expanded",
318315
"onnx/node/generated/test_convtranspose_output_shape",
319316
"onnx/node/generated/test_gridsample_nearest",
320317
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1",

tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna3_O3.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@
308308
"onnx/node/generated/test_bernoulli_double",
309309
"onnx/node/generated/test_bernoulli_double_expanded",
310310
"onnx/node/generated/test_bernoulli_expanded",
311-
"onnx/node/generated/test_cast_FLOAT_to_BFLOAT16",
312-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16",
313-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16_expanded",
314311
"onnx/node/generated/test_convtranspose_output_shape",
315312
"onnx/node/generated/test_gridsample_nearest",
316313
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1",

tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_hip_rdna4_O3.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@
308308
"onnx/node/generated/test_bernoulli_double",
309309
"onnx/node/generated/test_bernoulli_double_expanded",
310310
"onnx/node/generated/test_bernoulli_expanded",
311-
"onnx/node/generated/test_cast_FLOAT_to_BFLOAT16",
312-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16",
313-
"onnx/node/generated/test_castlike_FLOAT_to_BFLOAT16_expanded",
314311
"onnx/node/generated/test_convtranspose_output_shape",
315312
"onnx/node/generated/test_gridsample_nearest",
316313
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1",

tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan_O0.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,6 @@
348348
"onnx/node/generated/test_cast_DOUBLE_to_FLOAT16",
349349
"onnx/node/generated/test_cast_FLOAT16_to_DOUBLE",
350350
"onnx/node/generated/test_cast_FLOAT16_to_FLOAT",
351-
"onnx/node/generated/test_cast_FLOAT_to_BFLOAT16",
352351
"onnx/node/generated/test_cast_FLOAT_to_DOUBLE",
353352
"onnx/node/generated/test_cast_FLOAT_to_FLOAT16",
354353
"onnx/node/generated/test_castlike_DOUBLE_to_FLOAT",

0 commit comments

Comments
 (0)