Skip to content

Commit

Permalink
Support Half/BFloat16 in amax/amin (#7767)
Browse files Browse the repository at this point in the history
Partial fix for #7748.
  • Loading branch information
swolchok authored and YIWENX14 committed Jan 28, 2025
1 parent 887310e commit 0c2864a
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 74 deletions.
25 changes: 12 additions & 13 deletions kernels/portable/cpu/op_amax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,18 @@ Tensor& amax_out(
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(
Bool, in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
[](CTYPE v, CTYPE max_v) {
return std::isnan(v) || v > max_v ? v : max_v;
},
in,
dim_list,
out_ix);
}
});
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
[](CTYPE v, CTYPE max_v) {
return std::isnan(v) || v > max_v ? v : max_v;
},
in,
dim_list,
out_ix);
}
});

return out;
}
Expand Down
25 changes: 12 additions & 13 deletions kernels/portable/cpu/op_amin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,18 @@ Tensor& amin_out(
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(
Bool, in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
[](CTYPE v, CTYPE min_v) {
return std::isnan(v) || v < min_v ? v : min_v;
},
in,
dim_list,
out_ix);
}
});
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
[](CTYPE v, CTYPE min_v) {
return std::isnan(v) || v < min_v ? v : min_v;
},
in,
dim_list,
out_ix);
}
});

return out;
}
Expand Down
60 changes: 36 additions & 24 deletions kernels/test/op_amax_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,37 @@ class OpAmaxOutTest : public OperatorTest {
op_amax_out(in, empty_dim_list, /*keepdim=*/false, out);
EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
}

template <ScalarType DTYPE>
void test_amax_out_infinity_nan() {
TensorFactory<DTYPE> tf_dtype;
using CTYPE = typename decltype(tf_dtype)::ctype;
const auto infinity = std::numeric_limits<CTYPE>::infinity();
const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();

// clang-format off
Tensor in = tf_dtype.make(
{2, 3, 4},
{
0, 1, 2, infinity,
infinity, -infinity, 1, 0,
nan, infinity, -infinity, 2,

nan, nan, 1, 0,
0, infinity, nan, 4,
1, nan, 3.14, 2,
});
// clang-format on

Tensor out = tf_dtype.zeros({2, 3, 1});
int64_t dims[1] = {-1};
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
op_amax_out(in, dim_list, /*keepdim=*/true, out);
// clang-format off
EXPECT_TENSOR_CLOSE(
out, tf_dtype.make({2, 3, 1}, {infinity, infinity, nan, nan, nan, nan}));
// clang-format on
}
};

template <>
Expand Down Expand Up @@ -280,32 +311,13 @@ TEST_F(OpAmaxOutTest, MismatchedDTypesDies) {

TEST_F(OpAmaxOutTest, AllRealInputOutputPasses) {
#define TEST_ENTRY(ctype, dtype) test_amax_out_dtype<ScalarType::dtype>();
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpAmaxOutTest, InfinityAndNANTest) {
TensorFactory<ScalarType::Float> tf_float;
// clang-format off
Tensor in = tf_float.make(
{2, 3, 4},
{
0, 1, 2, INFINITY,
INFINITY, -INFINITY, 1, 0,
NAN, INFINITY, -INFINITY, 2,

NAN, NAN, 1, 0,
0, INFINITY, NAN, 4,
1, NAN, 3.14, 2,
});
// clang-format on

Tensor out = tf_float.zeros({2, 3, 1});
int64_t dims[1] = {-1};
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
op_amax_out(in, dim_list, /*keepdim=*/true, out);
// clang-format off
EXPECT_TENSOR_CLOSE(
out, tf_float.make({2, 3, 1}, {INFINITY, INFINITY, NAN, NAN, NAN, NAN}));
// clang-format on
#define TEST_ENTRY(ctype, dtype) \
test_amax_out_infinity_nan<ScalarType::dtype>();
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
59 changes: 35 additions & 24 deletions kernels/test/op_amin_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,36 @@ class OpAminOutTest : public OperatorTest {
op_amin_out(in, empty_dim_list, /*keepdim=*/false, out);
EXPECT_TENSOR_CLOSE(out, tf.make({}, {2}));
}

template <ScalarType DTYPE>
void test_amin_out_infinity_nan() {
TensorFactory<DTYPE> tf_dtype;
using CTYPE = typename decltype(tf_dtype)::ctype;
const auto infinity = std::numeric_limits<CTYPE>::infinity();
const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();
// clang-format off
Tensor in = tf_dtype.make(
{2, 3, 4},
{
0, 1, 2, infinity,
infinity, -infinity, 1, 0,
nan, infinity, -infinity, 2,

nan, nan, 1, 0,
0, infinity, nan, 4,
1, nan, 3.14, 2,
});
// clang-format on

Tensor out = tf_dtype.zeros({2, 3, 1});
int64_t dims[1] = {-1};
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
op_amin_out(in, dim_list, /*keepdim=*/true, out);
// clang-format off
EXPECT_TENSOR_CLOSE(
out, tf_dtype.make({2, 3, 1}, {0, -infinity, nan, nan, nan, nan}));
// clang-format on
}
};

template <>
Expand Down Expand Up @@ -280,32 +310,13 @@ TEST_F(OpAminOutTest, MismatchedDTypesDies) {

TEST_F(OpAminOutTest, AllRealInputOutputPasses) {
#define TEST_ENTRY(ctype, dtype) test_amin_out_dtype<ScalarType::dtype>();
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpAminOutTest, InfinityAndNANTest) {
TensorFactory<ScalarType::Float> tf_float;
// clang-format off
Tensor in = tf_float.make(
{2, 3, 4},
{
0, 1, 2, INFINITY,
INFINITY, -INFINITY, 1, 0,
NAN, INFINITY, -INFINITY, 2,

NAN, NAN, 1, 0,
0, INFINITY, NAN, 4,
1, NAN, 3.14, 2,
});
// clang-format on

Tensor out = tf_float.zeros({2, 3, 1});
int64_t dims[1] = {-1};
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
op_amin_out(in, dim_list, /*keepdim=*/true, out);
// clang-format off
EXPECT_TENSOR_CLOSE(
out, tf_float.make({2, 3, 1}, {0, -INFINITY, NAN, NAN, NAN, NAN}));
// clang-format on
#define TEST_ENTRY(ctype, dtype) \
test_amin_out_infinity_nan<ScalarType::dtype>();
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

0 comments on commit 0c2864a

Please sign in to comment.