
Commit 43330ff

pytorchbot authored and swolchok committed

Support Half/BFloat16 in op_allclose

Pull Request resolved: #7766

We incorrectly required these types to be bitwise-identical rather than close. (I had to develop this internally because the op_allclose_test doesn't run in OSS.)

Differential Revision: [D68366831](https://our.internmc.facebook.com/intern/diff/D68366831/)
ghstack-source-id: 262600586
Co-authored-by: Scott Wolchok <[email protected]>

1 parent de15762 commit 43330ff

File tree

2 files changed: +88 −220 lines


kernels/portable/cpu/op_allclose.cpp (+14 lines)
@@ -84,6 +84,20 @@ bool tensors_are_close(
         a.numel(),
         rtol,
         atol);
+  } else if (a.scalar_type() == ScalarType::Half) {
+    return data_is_close<Half>(
+        a.const_data_ptr<Half>(),
+        b.const_data_ptr<Half>(),
+        a.numel(),
+        rtol,
+        atol);
+  } else if (a.scalar_type() == ScalarType::BFloat16) {
+    return data_is_close<BFloat16>(
+        a.const_data_ptr<BFloat16>(),
+        b.const_data_ptr<BFloat16>(),
+        a.numel(),
+        rtol,
+        atol);
   } else {
     // Non-floating-point types can be compared bitwise.
     return memcmp(a.mutable_data_ptr(), b.mutable_data_ptr(), a.nbytes()) == 0;
