@@ -166,6 +166,7 @@ static const iree_hal_buffer_equality_t kApproximateAbsoluteEquality = ([]() {
166166 equality.f16_threshold = 0 .001f ;
167167 equality.f32_threshold = 0 .0001f ;
168168 equality.f64_threshold = 0.0001 ;
169+ equality.bf16_threshold = 0 .01f ;
169170 return equality;
170171})();
171172
@@ -175,6 +176,7 @@ static const iree_hal_buffer_equality_t kApproximateRelativeEquality = ([]() {
175176 equality.f16_threshold = 0 .001f ;
176177 equality.f32_threshold = 0 .0001f ;
177178 equality.f64_threshold = 0.0001 ;
179+ equality.bf16_threshold = 0 .01f ;
178180 return equality;
179181})();
180182
@@ -401,6 +403,64 @@ TEST_F(BufferViewMatchersTest, CompareBroadcastF64NERelative) {
401403 EXPECT_EQ (index, 1 );
402404}
403405
406+ TEST_F (BufferViewMatchersTest, CompareBroadcastBF16EQ) {
407+ const float lhs = 1 .0f ;
408+ const uint16_t rhs[] = {
409+ iree_math_f32_to_bf16 (1 .0f ),
410+ iree_math_f32_to_bf16 (1 .0f ),
411+ iree_math_f32_to_bf16 (1 .0f ),
412+ };
413+ iree_host_size_t index = 0 ;
414+ EXPECT_TRUE (iree_hal_compare_buffer_elements_broadcast (
415+ kApproximateAbsoluteEquality , iree_hal_make_buffer_element_bf16 (lhs),
416+ IREE_ARRAYSIZE (rhs), iree_make_const_byte_span (rhs, sizeof (rhs)),
417+ &index));
418+ }
419+
420+ TEST_F (BufferViewMatchersTest, CompareBroadcastBF16EQRelative) {
421+ const float lhs = 1 .0f ;
422+ const uint16_t rhs[] = {
423+ iree_math_f32_to_bf16 (1 .0f ),
424+ iree_math_f32_to_bf16 (1 .0f ),
425+ iree_math_f32_to_bf16 (1 .0f ),
426+ };
427+ iree_host_size_t index = 0 ;
428+ EXPECT_TRUE (iree_hal_compare_buffer_elements_broadcast (
429+ kApproximateRelativeEquality , iree_hal_make_buffer_element_bf16 (lhs),
430+ IREE_ARRAYSIZE (rhs), iree_make_const_byte_span (rhs, sizeof (rhs)),
431+ &index));
432+ }
433+
434+ TEST_F (BufferViewMatchersTest, CompareBroadcastBF16NE) {
435+ const float lhs = 1 .0f ;
436+ const uint16_t rhs[] = {
437+ iree_math_f32_to_bf16 (1 .0f ),
438+ iree_math_f32_to_bf16 (3 .0f ),
439+ iree_math_f32_to_bf16 (4 .0f ),
440+ };
441+ iree_host_size_t index = 0 ;
442+ EXPECT_FALSE (iree_hal_compare_buffer_elements_broadcast (
443+ kApproximateAbsoluteEquality , iree_hal_make_buffer_element_bf16 (lhs),
444+ IREE_ARRAYSIZE (rhs), iree_make_const_byte_span (rhs, sizeof (rhs)),
445+ &index));
446+ EXPECT_EQ (index, 1 );
447+ }
448+
449+ TEST_F (BufferViewMatchersTest, CompareBroadcastBF16NERelative) {
450+ const float lhs = 1 .0f ;
451+ const uint16_t rhs[] = {
452+ iree_math_f32_to_bf16 (1 .0f ),
453+ iree_math_f32_to_bf16 (3 .0f ),
454+ iree_math_f32_to_bf16 (4 .0f ),
455+ };
456+ iree_host_size_t index = 0 ;
457+ EXPECT_FALSE (iree_hal_compare_buffer_elements_broadcast (
458+ kApproximateRelativeEquality , iree_hal_make_buffer_element_bf16 (lhs),
459+ IREE_ARRAYSIZE (rhs), iree_make_const_byte_span (rhs, sizeof (rhs)),
460+ &index));
461+ EXPECT_EQ (index, 1 );
462+ }
463+
404464TEST_F (BufferViewMatchersTest, CompareElementwiseF16EQ) {
405465 const uint16_t lhs[] = {
406466 iree_math_f32_to_f16 (1 .0f ),
@@ -520,6 +580,83 @@ TEST_F(BufferViewMatchersTest, CompareElementwiseF64NE) {
520580 EXPECT_EQ (index, 1 );
521581}
522582
583+ TEST_F (BufferViewMatchersTest, CompareElementwiseBF16EQ) {
584+ const uint16_t lhs[] = {
585+ iree_math_f32_to_bf16 (1 .0f ),
586+ iree_math_f32_to_bf16 (2 .0f ),
587+ iree_math_f32_to_bf16 (3 .0f ),
588+ };
589+ const uint16_t rhs[] = {
590+ iree_math_f32_to_bf16 (1 .0f ),
591+ iree_math_f32_to_bf16 (2 .0f ),
592+ iree_math_f32_to_bf16 (3 .0f ),
593+ };
594+ iree_host_size_t index = 0 ;
595+ EXPECT_TRUE (iree_hal_compare_buffer_elements_elementwise (
596+ kApproximateAbsoluteEquality , IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
597+ IREE_ARRAYSIZE (lhs), iree_make_const_byte_span (lhs, sizeof (lhs)),
598+ iree_make_const_byte_span (rhs, sizeof (rhs)), &index));
599+ }
600+
601+ TEST_F (BufferViewMatchersTest, CompareElementwiseBF16NearEQ) {
602+ const uint16_t lhs[] = {
603+ iree_math_f32_to_bf16 (1 .0f ),
604+ iree_math_f32_to_bf16 (1 .99999f ),
605+ iree_math_f32_to_bf16 (0 .00001f ),
606+ iree_math_f32_to_bf16 (4 .0f ),
607+ };
608+ const uint16_t rhs[] = {
609+ iree_math_f32_to_bf16 (1 .00001f ),
610+ iree_math_f32_to_bf16 (2 .0f ),
611+ iree_math_f32_to_bf16 (0 .0f ),
612+ iree_math_f32_to_bf16 (4 .0f ),
613+ };
614+ iree_host_size_t index = 0 ;
615+ EXPECT_TRUE (iree_hal_compare_buffer_elements_elementwise (
616+ kApproximateAbsoluteEquality , IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
617+ IREE_ARRAYSIZE (lhs), iree_make_const_byte_span (lhs, sizeof (lhs)),
618+ iree_make_const_byte_span (rhs, sizeof (rhs)), &index));
619+ }
620+
621+ TEST_F (BufferViewMatchersTest, CompareElementwiseBF16NearEQRelative) {
622+ const uint16_t lhs[] = {
623+ iree_math_f32_to_bf16 (100 .0f ),
624+ iree_math_f32_to_bf16 (19 .99f ),
625+ iree_math_f32_to_bf16 (1 .00001f ),
626+ iree_math_f32_to_bf16 (4 .0f ),
627+ };
628+ const uint16_t rhs[] = {
629+ iree_math_f32_to_bf16 (100 .01f ),
630+ iree_math_f32_to_bf16 (20 .00f ),
631+ iree_math_f32_to_bf16 (1 .0f ),
632+ iree_math_f32_to_bf16 (4 .0f ),
633+ };
634+ iree_host_size_t index = 0 ;
635+ EXPECT_TRUE (iree_hal_compare_buffer_elements_elementwise (
636+ kApproximateRelativeEquality , IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
637+ IREE_ARRAYSIZE (lhs), iree_make_const_byte_span (lhs, sizeof (lhs)),
638+ iree_make_const_byte_span (rhs, sizeof (rhs)), &index));
639+ }
640+
641+ TEST_F (BufferViewMatchersTest, CompareElementwiseBF16NE) {
642+ const uint16_t lhs[] = {
643+ iree_math_f32_to_bf16 (1 .0f ),
644+ iree_math_f32_to_bf16 (2 .0f ),
645+ iree_math_f32_to_bf16 (4 .0f ),
646+ };
647+ const uint16_t rhs[] = {
648+ iree_math_f32_to_bf16 (1 .0f ),
649+ iree_math_f32_to_bf16 (3 .0f ),
650+ iree_math_f32_to_bf16 (4 .0f ),
651+ };
652+ iree_host_size_t index = 0 ;
653+ EXPECT_FALSE (iree_hal_compare_buffer_elements_elementwise (
654+ kApproximateAbsoluteEquality , IREE_HAL_ELEMENT_TYPE_BFLOAT_16,
655+ IREE_ARRAYSIZE (lhs), iree_make_const_byte_span (lhs, sizeof (lhs)),
656+ iree_make_const_byte_span (rhs, sizeof (rhs)), &index));
657+ EXPECT_EQ (index, 1 );
658+ }
659+
523660// ===----------------------------------------------------------------------===//
524661// iree_hal_buffer_view_metadata_matcher_t
525662// ===----------------------------------------------------------------------===//
@@ -804,5 +941,41 @@ TEST_F(BufferViewMatchersTest, MismatchContentsF16) {
804941 EXPECT_THAT (sb.ToString (), HasSubstr (" element at index 0" ));
805942}
806943
944+ TEST_F (BufferViewMatchersTest, MatchContentsBF16) {
945+ const uint16_t lhs_contents[] = {iree_math_f32_to_bf16 (2 .0f )};
946+ const uint16_t rhs_contents[] = {iree_math_f32_to_bf16 (2 .0f )};
947+ iree_hal_dim_t shape[] = {1 };
948+ IREE_ASSERT_OK_AND_ASSIGN (
949+ auto lhs,
950+ CreateBufferView (shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, lhs_contents));
951+ IREE_ASSERT_OK_AND_ASSIGN (
952+ auto rhs,
953+ CreateBufferView (shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, rhs_contents));
954+ auto sb = StringBuilder::MakeSystem ();
955+ bool match = false ;
956+ IREE_ASSERT_OK (
957+ iree_hal_buffer_view_match_equal (kExactEquality , lhs, rhs, sb, &match));
958+ EXPECT_TRUE (match);
959+ EXPECT_TRUE (sb.ToString ().empty ());
960+ }
961+
962+ TEST_F (BufferViewMatchersTest, MismatchContentsBF16) {
963+ const uint16_t lhs_contents[] = {iree_math_f32_to_bf16 (1 .0f )};
964+ const uint16_t rhs_contents[] = {iree_math_f32_to_bf16 (2 .0f )};
965+ const iree_hal_dim_t shape[] = {1 };
966+ IREE_ASSERT_OK_AND_ASSIGN (
967+ auto lhs,
968+ CreateBufferView (shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, lhs_contents));
969+ IREE_ASSERT_OK_AND_ASSIGN (
970+ auto rhs,
971+ CreateBufferView (shape, IREE_HAL_ELEMENT_TYPE_BFLOAT_16, rhs_contents));
972+ auto sb = StringBuilder::MakeSystem ();
973+ bool match = false ;
974+ IREE_ASSERT_OK (
975+ iree_hal_buffer_view_match_equal (kExactEquality , lhs, rhs, sb, &match));
976+ EXPECT_FALSE (match);
977+ EXPECT_THAT (sb.ToString (), HasSubstr (" element at index 0" ));
978+ }
979+
807980} // namespace
808981} // namespace iree
0 commit comments