Skip to content

Commit 7db6ca1

Browse files
committed
tests: benchdnn: update post-ops to support binary select op
1 parent b05e6ab commit 7db6ca1

File tree

6 files changed

+93
-24
lines changed

6 files changed

+93
-24
lines changed

tests/benchdnn/dnn_types.cpp

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -582,6 +582,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
582582
const auto &e = this->entry[idx];
583583
int mask = -1;
584584
int arg = DNNL_ARG_UNDEF;
585+
586+
int mask2 = -1;
587+
int arg2 = DNNL_ARG_UNDEF;
588+
585589
if (e.is_binary_kind()) {
586590
using mask_input_t = entry_t::binary_t::mask_input_t;
587591
auto mask_input = e.binary.mask_input;
@@ -590,6 +594,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
590594
: policy2mask(
591595
DNNL_ARG_SRC_1, e.binary.policy, prim_kind, ndims);
592596
arg = DNNL_ARG_SRC_1;
597+
598+
if (e.is_binary_kind_with_ternary_op()) {
599+
mask2 = e.binary.mask;
600+
arg2 = DNNL_ARG_SRC_2;
601+
}
602+
593603
} else if (e.is_prelu_kind()) {
594604
mask = attr_t::get_default_mask(e.prelu.policy);
595605
arg = DNNL_ARG_WEIGHTS;
@@ -599,6 +609,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
599609
assert(mask >= 0);
600610
v_masks.emplace_back(std::make_pair(
601611
DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg, mask));
612+
613+
if (e.is_binary_kind_with_ternary_op())
614+
v_masks.emplace_back(std::make_pair(
615+
DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg2, mask2));
602616
}
603617
return v_masks;
604618
}
@@ -630,11 +644,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
630644
return kind > ELTWISE_START && kind < ELTWISE_END;
631645
}
632646
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
633-
// binary select is a ternary operation and not currently
634-
// supported in post-ops for the binary primitive
635-
// TODO: add post-ops support for binary select operation
636-
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END
637-
&& kind != pk_t::SELECT;
647+
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END;
648+
}
649+
bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op() const {
650+
return kind == pk_t::SELECT;
638651
}
639652
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
640653
return kind == PRELU;
@@ -1059,6 +1072,22 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10591072
mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
10601073
| po_rhs_tensor_entry.arg_attr_mask),
10611074
std::move(rhs_tensor_desc));
1075+
1076+
if (e.is_binary_kind_with_ternary_op()) {
1077+
1078+
const post_ops_rhs_tensor_entry_t
1079+
post_op_rhs_select_tensor_entry
1080+
= {e.binary.src2_dt, 0, e.binary.tag, DNNL_ARG_SRC_2};
1081+
1082+
auto rhs_select_tensor_desc = dnn_mem_t::init_md(ndims, dims,
1083+
e.binary.src2_dt, post_op_rhs_select_tensor_entry.tag);
1084+
1085+
mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
1086+
| post_op_rhs_select_tensor_entry
1087+
.arg_attr_mask),
1088+
std::move(rhs_select_tensor_desc));
1089+
}
1090+
10621091
} else if (e.is_convolution_kind()) {
10631092
// Update dims for post operations appended after conv_dw
10641093
conv_dw_fusion::get_fused_conv_dst_dims(ndims, e, dims, dims);
@@ -1162,9 +1191,16 @@ dnnl_primitive_attr_t create_dnnl_attr(
11621191
} else if (e.is_binary_kind()) {
11631192
const auto &src1_md = attr_args.get_md(
11641193
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
1194+
const auto &src2_md = attr_args.get_md(
1195+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2));
11651196
assert(query_md_ndims(src1_md) != 0);
1197+
1198+
if (e.is_binary_kind_with_ternary_op())
1199+
assert(query_md_ndims(src2_md) != 0);
1200+
11661201
DNN_SAFE_V(dnnl_post_ops_append_binary(
1167-
ops, e.binary.alg, src1_md));
1202+
ops, e.binary.alg, src1_md, src2_md));
1203+
11681204
} else if (e.is_prelu_kind()) {
11691205
const auto &policy = e.prelu.policy;
11701206
const auto mask = attr_t::get_default_mask(policy);
@@ -1585,6 +1621,9 @@ float compute_eltwise_bwd(
15851621

15861622
float compute_binary(pk_t kind, float src0, float src1, bool src2) {
15871623
// don't compute on nan, propagate it
1624+
1625+
printf("src0: %f,\tsrc1: %f,\tsrc2: %d\n", src0, src1, src2);
1626+
15881627
if (std::isnan(src0) || std::isnan(src1)) return NAN;
15891628

15901629
if (kind == pk_t::ADD) {
@@ -1672,7 +1711,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16721711
const auto &b = e.eltwise.beta;
16731712
val = compute_eltwise_fwd(e.kind, val, a, b);
16741713
} else if (e.is_binary_kind()) {
1675-
val = compute_binary(e.kind, val, *it_po, false);
1714+
1715+
auto src1_val = *it_po;
1716+
bool src2_val = false;
1717+
1718+
if (e.is_binary_kind_with_ternary_op()) {
1719+
it_po++;
1720+
src2_val = *it_po;
1721+
}
1722+
val = compute_binary(e.kind, val, src1_val, src2_val);
16761723
it_po++;
16771724
} else if (e.is_prelu_kind()) {
16781725
val = val > 0 ? val : val * (*it_po);

tests/benchdnn/dnn_types.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -323,6 +323,7 @@ struct attr_t {
323323

324324
dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
325325
dnnl_data_type_t src1_dt = dnnl_data_type_undef;
326+
dnnl_data_type_t src2_dt = dnnl_data_type_undef;
326327
policy_t policy = policy_t::COMMON;
327328
int64_t mask = -1;
328329
mask_input_t mask_input = mask_input_t::none;
@@ -336,6 +337,7 @@ struct attr_t {
336337
bool is_convolution_kind() const;
337338
bool is_eltwise_kind() const;
338339
bool is_binary_kind() const;
340+
bool is_binary_kind_with_ternary_op() const;
339341
bool is_prelu_kind() const;
340342
};
341343

tests/benchdnn/dnnl_common.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -1644,15 +1644,25 @@ int init_ref_memory_args_default_case(int exec_arg, dnn_mem_t &mem,
16441644
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
16451645
assert(bin_po_idx < attr.post_ops.len());
16461646
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
1647+
const bool is_src2_mem
1648+
= ((exec_arg & DNNL_ARG_SRC_2) == DNNL_ARG_SRC_2);
1649+
16471650
// Binary post-op filling.
1648-
fill_cfg_t def_binary_cfg(mem.dt(), -16.f, 16.f, /* int = */ true,
1649-
alg, "def_binary_post_op");
1650-
const auto it = fill_cfg_map.find(DNNL_ARG_SRC_1);
1651-
const bool has_external_cfg = it != fill_cfg_map.end();
1652-
const fill_cfg_t &binary_fill_cfg
1653-
= has_external_cfg ? (*it).second : def_binary_cfg;
1654-
TIME_FILL(SAFE(fill_random_real(mem, ref_mem, res, binary_fill_cfg),
1655-
WARN));
1651+
if (!is_src2_mem
1652+
|| attr.post_ops.entry[bin_po_idx]
1653+
.is_binary_kind_with_ternary_op()) {
1654+
const int f_min = is_src2_mem ? 0 : -16.f;
1655+
fill_cfg_t def_binary_cfg(mem.dt(), f_min, 16.f,
1656+
/* int = */ true, alg, "def_binary_post_op");
1657+
const auto it = fill_cfg_map.find(
1658+
is_src2_mem ? DNNL_ARG_SRC_2 : DNNL_ARG_SRC_1);
1659+
const bool has_external_cfg = it != fill_cfg_map.end();
1660+
const fill_cfg_t &binary_fill_cfg
1661+
= has_external_cfg ? (*it).second : def_binary_cfg;
1662+
TIME_FILL(SAFE(
1663+
fill_random_real(mem, ref_mem, res, binary_fill_cfg),
1664+
WARN));
1665+
}
16561666
} else if (exec_arg & DNNL_ARG_WEIGHTS) {
16571667
// Prelu post-op filling.
16581668
fill_cfg_t def_prelu_fill_cfg(mem.dt(), -2.f, 2.f, /* int = */ true,

tests/benchdnn/dnnl_common.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -861,9 +861,13 @@ void init_memory_args(dnn_mem_map_t &mem_map, const prb_t *prb,
861861
for (int idx = 0; idx < dnnl_post_ops_len(const_po); ++idx) {
862862
if (dnnl_post_ops_get_kind(const_po, idx) != dnnl_binary) continue;
863863

864-
int po_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
865-
const auto &po_md = query_md(const_pd, po_arg);
866-
mem_map.emplace(po_arg, dnn_mem_t(po_md, test_engine));
864+
int po_arg1 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
865+
const auto &po_md1 = query_md(const_pd, po_arg1);
866+
mem_map.emplace(po_arg1, dnn_mem_t(po_md1, test_engine));
867+
868+
int po_arg2 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2;
869+
const auto &po_md2 = query_md(const_pd, po_arg2);
870+
mem_map.emplace(po_arg2, dnn_mem_t(po_md2, test_engine));
867871
}
868872

869873
// Prelu post-op.

tests/benchdnn/self/common.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -344,6 +344,9 @@ void append_binary(attr_t::post_ops_t &po, pk_t akind, dnnl_data_type_t src_dt1,
344344
attr_t::post_ops_t::entry_t e(akind);
345345
e.binary.alg = attr_t::post_ops_t::kind2dnnl_kind(akind);
346346
e.binary.src1_dt = src_dt1;
347+
348+
if (e.is_binary_kind_with_ternary_op()) e.binary.src2_dt = dnnl_s8;
349+
347350
e.binary.mask_input = mask_input;
348351
e.binary.mask = mask;
349352
e.binary.policy = policy;

tests/benchdnn/utils/parser.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2019-2024 Intel Corporation
2+
* Copyright 2019-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -217,6 +217,9 @@ attr_t::post_ops_t parse_attr_post_ops_func(const std::string &s) {
217217
} else if (e.is_binary_kind()) {
218218
const auto dt_str = get_substr(subs, subs_pos, ':');
219219
e.binary.src1_dt = str2dt(dt_str.c_str());
220+
221+
if (e.is_binary_kind_with_ternary_op()) e.binary.src2_dt = dnnl_s8;
222+
220223
if (e.binary.src1_dt == dnnl_data_type_undef) {
221224
BENCHDNN_PRINT(0, "%s \'%s\' %s\n",
222225
"Error: binary post-op data type", dt_str.c_str(),

0 commit comments

Comments
 (0)