Skip to content

Commit 99489fe

Browse files
authored
Support Half/BFloat16 in op_allclose
Differential Revision: D68366831 Pull Request resolved: #7766
1 parent 43580f5 commit 99489fe

File tree

2 files changed

+88
-220
lines changed

2 files changed

+88
-220
lines changed

kernels/portable/cpu/op_allclose.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ bool tensors_are_close(
8484
a.numel(),
8585
rtol,
8686
atol);
87+
} else if (a.scalar_type() == ScalarType::Half) {
88+
return data_is_close<Half>(
89+
a.const_data_ptr<Half>(),
90+
b.const_data_ptr<Half>(),
91+
a.numel(),
92+
rtol,
93+
atol);
94+
} else if (a.scalar_type() == ScalarType::BFloat16) {
95+
return data_is_close<BFloat16>(
96+
a.const_data_ptr<BFloat16>(),
97+
b.const_data_ptr<BFloat16>(),
98+
a.numel(),
99+
rtol,
100+
atol);
87101
} else {
88102
// Non-floating-point types can be compared bitwise.
89103
return memcmp(a.mutable_data_ptr(), b.mutable_data_ptr(), a.nbytes()) == 0;

0 commit comments

Comments
 (0)