From 86fca5a69e987f1cae264f7ccc683b32629d6561 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 17 Jan 2025 18:40:09 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_any.cpp | 6 +++--- kernels/test/op_any_test.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp index e2bd8ebe199..2168859582b 100644 --- a/kernels/portable/cpu/op_any.cpp +++ b/kernels/portable/cpu/op_any.cpp @@ -29,7 +29,7 @@ Tensor& any_all_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ScalarType out_type = out.scalar_type(); constexpr auto name = "any.all_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] { const auto data_in = in.const_data_ptr(); auto data_out = out.mutable_data_ptr(); @@ -78,7 +78,7 @@ Tensor& any_dims_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "any.dims_out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] { CTYPE_OUT* out_data = out.mutable_data_ptr(); if (dim_list.has_value() && dim_list.value().empty()) { @@ -135,7 +135,7 @@ Tensor& any_out( ScalarType out_type = out.scalar_type(); constexpr auto name = "any.out"; - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] { ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] { CTYPE_OUT* out_data = out.mutable_data_ptr(); for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) { diff --git a/kernels/test/op_any_test.cpp b/kernels/test/op_any_test.cpp index 09f9cdd4991..400602bd2f6 100644 --- a/kernels/test/op_any_test.cpp +++ b/kernels/test/op_any_test.cpp @@ -120,7 +120,7 @@ TEST_F(OpAnyOutTest, InvalidDtypeDies) { TEST_F(OpAnyOutTest, AllRealInputTypePasses) { #define TEST_ENTRY(ctype, dtype) test_any_all_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }