diff --git a/kernels/portable/cpu/op_amax.cpp b/kernels/portable/cpu/op_amax.cpp
index 088c30a375..519aa8ac92 100644
--- a/kernels/portable/cpu/op_amax.cpp
+++ b/kernels/portable/cpu/op_amax.cpp
@@ -42,19 +42,18 @@ Tensor& amax_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
-        CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
-        for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-          out_data[out_ix] = reduce_over_dim_list(
-              [](CTYPE v, CTYPE max_v) {
-                return std::isnan(v) || v > max_v ? v : max_v;
-              },
-              in,
-              dim_list,
-              out_ix);
-        }
-      });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+    for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+      out_data[out_ix] = reduce_over_dim_list(
+          [](CTYPE v, CTYPE max_v) {
+            return std::isnan(v) || v > max_v ? v : max_v;
+          },
+          in,
+          dim_list,
+          out_ix);
+    }
+  });
 
   return out;
 }
diff --git a/kernels/portable/cpu/op_amin.cpp b/kernels/portable/cpu/op_amin.cpp
index 9f2aa38c89..0f77980a50 100644
--- a/kernels/portable/cpu/op_amin.cpp
+++ b/kernels/portable/cpu/op_amin.cpp
@@ -42,19 +42,18 @@ Tensor& amin_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
-        CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
-        for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-          out_data[out_ix] = reduce_over_dim_list(
-              [](CTYPE v, CTYPE min_v) {
-                return std::isnan(v) || v < min_v ? v : min_v;
-              },
-              in,
-              dim_list,
-              out_ix);
-        }
-      });
+  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+    for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
+      out_data[out_ix] = reduce_over_dim_list(
+          [](CTYPE v, CTYPE min_v) {
+            return std::isnan(v) || v < min_v ? v : min_v;
+          },
+          in,
+          dim_list,
+          out_ix);
+    }
+  });
 
   return out;
 }
diff --git a/kernels/test/op_amax_test.cpp b/kernels/test/op_amax_test.cpp
index ba0ac0527b..a7794cc24c 100644
--- a/kernels/test/op_amax_test.cpp
+++ b/kernels/test/op_amax_test.cpp
@@ -189,6 +189,37 @@ class OpAmaxOutTest : public OperatorTest {
     op_amax_out(in, empty_dim_list, /*keepdim=*/false, out);
     EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
   }
+
+  template <ScalarType DTYPE>
+  void test_amax_out_infinity_nan() {
+    TensorFactory<DTYPE> tf_dtype;
+    using CTYPE = typename decltype(tf_dtype)::ctype;
+    const auto infinity = std::numeric_limits<CTYPE>::infinity();
+    const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();
+
+    // clang-format off
+    Tensor in = tf_dtype.make(
+      {2, 3, 4},
+      {
+        0,        1,         2,         infinity,
+        infinity, -infinity, 1,         0,
+        nan,      infinity,  -infinity, 2,
+
+        nan, nan,      1,    0,
+        0,   infinity, nan,  4,
+        1,   nan,      3.14, 2,
+      });
+    // clang-format on
+
+    Tensor out = tf_dtype.zeros({2, 3, 1});
+    int64_t dims[1] = {-1};
+    ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
+    op_amax_out(in, dim_list, /*keepdim=*/true, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(
+        out, tf_dtype.make({2, 3, 1}, {infinity, infinity, nan, nan, nan, nan}));
+    // clang-format on
+  }
 };
 
 template <>
@@ -280,32 +311,13 @@ TEST_F(OpAmaxOutTest, MismatchedDTypesDies) {
 
 TEST_F(OpAmaxOutTest, AllRealInputOutputPasses) {
 #define TEST_ENTRY(ctype, dtype) test_amax_out_dtype<ScalarType::dtype>();
-  ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
+  ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
 TEST_F(OpAmaxOutTest, InfinityAndNANTest) {
-  TensorFactory<ScalarType::Float> tf_float;
-  // clang-format off
-  Tensor in = tf_float.make(
-    {2, 3, 4},
-    {
-      0,        1,         2,         INFINITY,
-      INFINITY, -INFINITY, 1,         0,
-      NAN,      INFINITY,  -INFINITY, 2,
-
-      NAN, NAN,      1,    0,
-      0,   INFINITY, NAN,  4,
-      1,   NAN,      3.14, 2,
-    });
-  // clang-format on
-
-  Tensor out = tf_float.zeros({2, 3, 1});
-  int64_t dims[1] = {-1};
-  ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
-  op_amax_out(in, dim_list, /*keepdim=*/true, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(
-      out, tf_float.make({2, 3, 1}, {INFINITY, INFINITY, NAN, NAN, NAN, NAN}));
-  // clang-format on
+#define TEST_ENTRY(ctype, dtype) \
+  test_amax_out_infinity_nan<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
 }
diff --git a/kernels/test/op_amin_test.cpp b/kernels/test/op_amin_test.cpp
index 4d56e1389d..001bb93ca8 100644
--- a/kernels/test/op_amin_test.cpp
+++ b/kernels/test/op_amin_test.cpp
@@ -189,6 +189,36 @@ class OpAminOutTest : public OperatorTest {
     op_amin_out(in, empty_dim_list, /*keepdim=*/false, out);
     EXPECT_TENSOR_CLOSE(out, tf.make({}, {2}));
   }
+
+  template <ScalarType DTYPE>
+  void test_amin_out_infinity_nan() {
+    TensorFactory<DTYPE> tf_dtype;
+    using CTYPE = typename decltype(tf_dtype)::ctype;
+    const auto infinity = std::numeric_limits<CTYPE>::infinity();
+    const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();
+    // clang-format off
+    Tensor in = tf_dtype.make(
+      {2, 3, 4},
+      {
+        0,        1,         2,         infinity,
+        infinity, -infinity, 1,         0,
+        nan,      infinity,  -infinity, 2,
+
+        nan, nan,      1,    0,
+        0,   infinity, nan,  4,
+        1,   nan,      3.14, 2,
+      });
+    // clang-format on
+
+    Tensor out = tf_dtype.zeros({2, 3, 1});
+    int64_t dims[1] = {-1};
+    ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
+    op_amin_out(in, dim_list, /*keepdim=*/true, out);
+    // clang-format off
+    EXPECT_TENSOR_CLOSE(
+        out, tf_dtype.make({2, 3, 1}, {0, -infinity, nan, nan, nan, nan}));
+    // clang-format on
+  }
 };
 
 template <>
@@ -280,32 +310,13 @@ TEST_F(OpAminOutTest, MismatchedDTypesDies) {
 
 TEST_F(OpAminOutTest, AllRealInputOutputPasses) {
 #define TEST_ENTRY(ctype, dtype) test_amin_out_dtype<ScalarType::dtype>();
-  ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
+  ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 }
 
 TEST_F(OpAminOutTest, InfinityAndNANTest) {
-  TensorFactory<ScalarType::Float> tf_float;
-  // clang-format off
-  Tensor in = tf_float.make(
-    {2, 3, 4},
-    {
-      0,        1,         2,         INFINITY,
-      INFINITY, -INFINITY, 1,         0,
-      NAN,      INFINITY,  -INFINITY, 2,
-
-      NAN, NAN,      1,    0,
-      0,   INFINITY, NAN,  4,
-      1,   NAN,      3.14, 2,
-    });
-  // clang-format on
-
-  Tensor out = tf_float.zeros({2, 3, 1});
-  int64_t dims[1] = {-1};
-  ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
-  op_amin_out(in, dim_list, /*keepdim=*/true, out);
-  // clang-format off
-  EXPECT_TENSOR_CLOSE(
-      out, tf_float.make({2, 3, 1}, {0, -INFINITY, NAN, NAN, NAN, NAN}));
-  // clang-format on
+#define TEST_ENTRY(ctype, dtype) \
+  test_amin_out_infinity_nan<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
 }
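
Note (commentary, not part of the patch): the kernel change only swaps the dtype dispatch macro from ET_SWITCH_REAL_TYPES_AND(Bool, ...) to ET_SWITCH_REALHBBF16_TYPES, extending amax/amin to Half and BFloat16 while keeping the same NaN-propagating comparator. The sketch below is a minimal standalone illustration of that comparator's semantics on a plain std::vector; amax_slice is a hypothetical helper written for this note and is not part of the ExecuTorch API.

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical helper, for illustration only: reduces one slice the way the
// amax kernel's lambda does -- any NaN in the slice wins, otherwise the
// largest value wins.
template <typename CTYPE>
CTYPE amax_slice(const std::vector<CTYPE>& slice) {
  CTYPE max_v = slice.front();
  for (CTYPE v : slice) {
    max_v = (std::isnan(v) || v > max_v) ? v : max_v;
  }
  return max_v;
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Mirrors two rows of the test input reduced over the last dimension.
  std::printf("%f\n", amax_slice<float>({0, 1, 2, inf}));      // prints inf
  std::printf("%f\n", amax_slice<float>({nan, inf, -inf, 2})); // prints nan
  return 0;
}

Compiled as a standalone program, the two calls print inf and nan respectively, which is the behavior the new test_amax_out_infinity_nan / test_amin_out_infinity_nan helpers now check across all floating dtypes, including Half and BFloat16.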