Skip to content

Commit 4bf83b9

Browse files
authored
Support Half/BFloat16 in amax/amin (#7767)
Partial fix for #7748.
1 parent 02da069 commit 4bf83b9

File tree

4 files changed

+95
-74
lines changed

4 files changed

+95
-74
lines changed

kernels/portable/cpu/op_amax.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,18 @@ Tensor& amax_out(
4242
ET_KERNEL_CHECK(
4343
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4444

45-
ET_SWITCH_REAL_TYPES_AND(
46-
Bool, in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
47-
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
48-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
49-
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
50-
[](CTYPE v, CTYPE max_v) {
51-
return std::isnan(v) || v > max_v ? v : max_v;
52-
},
53-
in,
54-
dim_list,
55-
out_ix);
56-
}
57-
});
45+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
46+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
47+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
48+
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
49+
[](CTYPE v, CTYPE max_v) {
50+
return std::isnan(v) || v > max_v ? v : max_v;
51+
},
52+
in,
53+
dim_list,
54+
out_ix);
55+
}
56+
});
5857

5958
return out;
6059
}

kernels/portable/cpu/op_amin.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,18 @@ Tensor& amin_out(
4242
ET_KERNEL_CHECK(
4343
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4444

45-
ET_SWITCH_REAL_TYPES_AND(
46-
Bool, in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
47-
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
48-
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
49-
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
50-
[](CTYPE v, CTYPE min_v) {
51-
return std::isnan(v) || v < min_v ? v : min_v;
52-
},
53-
in,
54-
dim_list,
55-
out_ix);
56-
}
57-
});
45+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
46+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
47+
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
48+
out_data[out_ix] = reduce_over_dim_list<CTYPE>(
49+
[](CTYPE v, CTYPE min_v) {
50+
return std::isnan(v) || v < min_v ? v : min_v;
51+
},
52+
in,
53+
dim_list,
54+
out_ix);
55+
}
56+
});
5857

5958
return out;
6059
}

kernels/test/op_amax_test.cpp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,37 @@ class OpAmaxOutTest : public OperatorTest {
189189
op_amax_out(in, empty_dim_list, /*keepdim=*/false, out);
190190
EXPECT_TENSOR_CLOSE(out, tf.make({}, {9}));
191191
}
192+
193+
template <ScalarType DTYPE>
194+
void test_amax_out_infinity_nan() {
195+
TensorFactory<DTYPE> tf_dtype;
196+
using CTYPE = typename decltype(tf_dtype)::ctype;
197+
const auto infinity = std::numeric_limits<CTYPE>::infinity();
198+
const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();
199+
200+
// clang-format off
201+
Tensor in = tf_dtype.make(
202+
{2, 3, 4},
203+
{
204+
0, 1, 2, infinity,
205+
infinity, -infinity, 1, 0,
206+
nan, infinity, -infinity, 2,
207+
208+
nan, nan, 1, 0,
209+
0, infinity, nan, 4,
210+
1, nan, 3.14, 2,
211+
});
212+
// clang-format on
213+
214+
Tensor out = tf_dtype.zeros({2, 3, 1});
215+
int64_t dims[1] = {-1};
216+
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
217+
op_amax_out(in, dim_list, /*keepdim=*/true, out);
218+
// clang-format off
219+
EXPECT_TENSOR_CLOSE(
220+
out, tf_dtype.make({2, 3, 1}, {infinity, infinity, nan, nan, nan, nan}));
221+
// clang-format on
222+
}
192223
};
193224

194225
template <>
@@ -280,32 +311,13 @@ TEST_F(OpAmaxOutTest, MismatchedDTypesDies) {
280311

281312
TEST_F(OpAmaxOutTest, AllRealInputOutputPasses) {
282313
#define TEST_ENTRY(ctype, dtype) test_amax_out_dtype<ScalarType::dtype>();
283-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
314+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
284315
#undef TEST_ENTRY
285316
}
286317

287318
TEST_F(OpAmaxOutTest, InfinityAndNANTest) {
288-
TensorFactory<ScalarType::Float> tf_float;
289-
// clang-format off
290-
Tensor in = tf_float.make(
291-
{2, 3, 4},
292-
{
293-
0, 1, 2, INFINITY,
294-
INFINITY, -INFINITY, 1, 0,
295-
NAN, INFINITY, -INFINITY, 2,
296-
297-
NAN, NAN, 1, 0,
298-
0, INFINITY, NAN, 4,
299-
1, NAN, 3.14, 2,
300-
});
301-
// clang-format on
302-
303-
Tensor out = tf_float.zeros({2, 3, 1});
304-
int64_t dims[1] = {-1};
305-
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
306-
op_amax_out(in, dim_list, /*keepdim=*/true, out);
307-
// clang-format off
308-
EXPECT_TENSOR_CLOSE(
309-
out, tf_float.make({2, 3, 1}, {INFINITY, INFINITY, NAN, NAN, NAN, NAN}));
310-
// clang-format on
319+
#define TEST_ENTRY(ctype, dtype) \
320+
test_amax_out_infinity_nan<ScalarType::dtype>();
321+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
322+
#undef TEST_ENTRY
311323
}

kernels/test/op_amin_test.cpp

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,36 @@ class OpAminOutTest : public OperatorTest {
189189
op_amin_out(in, empty_dim_list, /*keepdim=*/false, out);
190190
EXPECT_TENSOR_CLOSE(out, tf.make({}, {2}));
191191
}
192+
193+
template <ScalarType DTYPE>
194+
void test_amin_out_infinity_nan() {
195+
TensorFactory<DTYPE> tf_dtype;
196+
using CTYPE = typename decltype(tf_dtype)::ctype;
197+
const auto infinity = std::numeric_limits<CTYPE>::infinity();
198+
const auto nan = std::numeric_limits<CTYPE>::quiet_NaN();
199+
// clang-format off
200+
Tensor in = tf_dtype.make(
201+
{2, 3, 4},
202+
{
203+
0, 1, 2, infinity,
204+
infinity, -infinity, 1, 0,
205+
nan, infinity, -infinity, 2,
206+
207+
nan, nan, 1, 0,
208+
0, infinity, nan, 4,
209+
1, nan, 3.14, 2,
210+
});
211+
// clang-format on
212+
213+
Tensor out = tf_dtype.zeros({2, 3, 1});
214+
int64_t dims[1] = {-1};
215+
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
216+
op_amin_out(in, dim_list, /*keepdim=*/true, out);
217+
// clang-format off
218+
EXPECT_TENSOR_CLOSE(
219+
out, tf_dtype.make({2, 3, 1}, {0, -infinity, nan, nan, nan, nan}));
220+
// clang-format on
221+
}
192222
};
193223

194224
template <>
@@ -280,32 +310,13 @@ TEST_F(OpAminOutTest, MismatchedDTypesDies) {
280310

281311
TEST_F(OpAminOutTest, AllRealInputOutputPasses) {
282312
#define TEST_ENTRY(ctype, dtype) test_amin_out_dtype<ScalarType::dtype>();
283-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
313+
ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
284314
#undef TEST_ENTRY
285315
}
286316

287317
TEST_F(OpAminOutTest, InfinityAndNANTest) {
288-
TensorFactory<ScalarType::Float> tf_float;
289-
// clang-format off
290-
Tensor in = tf_float.make(
291-
{2, 3, 4},
292-
{
293-
0, 1, 2, INFINITY,
294-
INFINITY, -INFINITY, 1, 0,
295-
NAN, INFINITY, -INFINITY, 2,
296-
297-
NAN, NAN, 1, 0,
298-
0, INFINITY, NAN, 4,
299-
1, NAN, 3.14, 2,
300-
});
301-
// clang-format on
302-
303-
Tensor out = tf_float.zeros({2, 3, 1});
304-
int64_t dims[1] = {-1};
305-
ArrayRef<int64_t> dim_list{ArrayRef<int64_t>{dims, 1}};
306-
op_amin_out(in, dim_list, /*keepdim=*/true, out);
307-
// clang-format off
308-
EXPECT_TENSOR_CLOSE(
309-
out, tf_float.make({2, 3, 1}, {0, -INFINITY, NAN, NAN, NAN, NAN}));
310-
// clang-format on
318+
#define TEST_ENTRY(ctype, dtype) \
319+
test_amin_out_infinity_nan<ScalarType::dtype>();
320+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
321+
#undef TEST_ENTRY
311322
}

0 commit comments

Comments (0)