Skip to content

Commit 4f5bc57

Browse files
committed
tests: benchdnn: update postops to support binary select op
1 parent d190ccb commit 4f5bc57

File tree

6 files changed

+83
-20
lines changed

6 files changed

+83
-20
lines changed

tests/benchdnn/dnn_types.cpp

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
582582
const auto &e = this->entry[idx];
583583
int mask = -1;
584584
int arg = DNNL_ARG_UNDEF;
585+
586+
int mask2 = -1;
587+
int arg2 = DNNL_ARG_UNDEF;
588+
585589
if (e.is_binary_kind()) {
586590
using mask_input_t = entry_t::binary_t::mask_input_t;
587591
auto mask_input = e.binary.mask_input;
@@ -590,6 +594,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
590594
: policy2mask(
591595
DNNL_ARG_SRC_1, e.binary.policy, prim_kind, ndims);
592596
arg = DNNL_ARG_SRC_1;
597+
598+
if (e.is_binary_kind_with_ternary_op()) {
599+
mask2 = e.binary.mask;
600+
arg2 = DNNL_ARG_SRC_2;
601+
}
602+
593603
} else if (e.is_prelu_kind()) {
594604
mask = attr_t::get_default_mask(e.prelu.policy);
595605
arg = DNNL_ARG_WEIGHTS;
@@ -599,6 +609,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
599609
assert(mask >= 0);
600610
v_masks.emplace_back(std::make_pair(
601611
DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg, mask));
612+
613+
if (e.is_binary_kind_with_ternary_op())
614+
v_masks.emplace_back(std::make_pair(
615+
DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg2, mask2));
602616
}
603617
return v_masks;
604618
}
@@ -631,11 +645,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
631645
return kind > ELTWISE_START && kind < ELTWISE_END;
632646
}
633647
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
634-
// binary select is a ternary operation and not currently
635-
// supported in post-ops for the binary primitive
636-
// TODO: add post-ops support for binary select operation
637-
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END
638-
&& kind != pk_t::SELECT;
648+
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END;
649+
}
650+
bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op() const {
651+
return kind == pk_t::SELECT;
639652
}
640653
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
641654
return kind == PRELU;
@@ -1062,6 +1075,22 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10621075
mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
10631076
| po_rhs_tensor_entry.arg_attr_mask),
10641077
std::move(rhs_tensor_desc));
1078+
1079+
if (e.is_binary_kind_with_ternary_op()) {
1080+
1081+
const post_ops_rhs_tensor_entry_t
1082+
post_op_rhs_select_tensor_entry
1083+
= {e.binary.src2_dt, 0, e.binary.tag, DNNL_ARG_SRC_2};
1084+
1085+
auto rhs_select_tensor_desc = dnn_mem_t::init_md(ndims, dims,
1086+
e.binary.src2_dt, post_op_rhs_select_tensor_entry.tag);
1087+
1088+
mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
1089+
| post_op_rhs_select_tensor_entry
1090+
.arg_attr_mask),
1091+
std::move(rhs_select_tensor_desc));
1092+
}
1093+
10651094
} else if (e.is_convolution_kind()) {
10661095
// Update dims for post operations appended after conv_dw
10671096
conv_dw_fusion::get_fused_conv_dst_dims(ndims, e, dims, dims);
@@ -1165,9 +1194,16 @@ dnnl_primitive_attr_t create_dnnl_attr(
11651194
} else if (e.is_binary_kind()) {
11661195
const auto &src1_md = attr_args.get_md(
11671196
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
1197+
const auto &src2_md = attr_args.get_md(
1198+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2));
11681199
assert(query_md_ndims(src1_md) != 0);
1169-
DNN_SAFE_V(dnnl_post_ops_append_binary(
1170-
ops, e.binary.alg, src1_md));
1200+
1201+
if (e.is_binary_kind_with_ternary_op())
1202+
assert(query_md_ndims(src2_md) != 0);
1203+
1204+
DNN_SAFE_V(dnnl_post_ops_append_binary_v2(
1205+
ops, e.binary.alg, src1_md, src2_md));
1206+
11711207
} else if (e.is_prelu_kind()) {
11721208
const auto &policy = e.prelu.policy;
11731209
const auto mask = attr_t::get_default_mask(policy);
@@ -1675,7 +1711,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16751711
const auto &b = e.eltwise.beta;
16761712
val = compute_eltwise_fwd(e.kind, val, a, b);
16771713
} else if (e.is_binary_kind()) {
1678-
val = compute_binary(e.kind, val, *it_po, false);
1714+
1715+
auto src1_val = *it_po;
1716+
bool src2_val = false;
1717+
1718+
if (e.is_binary_kind_with_ternary_op()) {
1719+
it_po++;
1720+
src2_val = *it_po;
1721+
}
1722+
val = compute_binary(e.kind, val, src1_val, src2_val);
16791723
it_po++;
16801724
} else if (e.is_prelu_kind()) {
16811725
val = val > 0 ? val : val * (*it_po);

tests/benchdnn/dnn_types.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ struct attr_t {
323323

324324
dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
325325
dnnl_data_type_t src1_dt = dnnl_data_type_undef;
326+
dnnl_data_type_t src2_dt = dnnl_data_type_undef;
326327
policy_t policy = policy_t::COMMON;
327328
int64_t mask = -1;
328329
mask_input_t mask_input = mask_input_t::none;
@@ -336,6 +337,7 @@ struct attr_t {
336337
bool is_convolution_kind() const;
337338
bool is_eltwise_kind() const;
338339
bool is_binary_kind() const;
340+
bool is_binary_kind_with_ternary_op() const;
339341
bool is_prelu_kind() const;
340342
};
341343

tests/benchdnn/dnnl_common.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,15 +1722,25 @@ int init_ref_memory_args_default_case(int exec_arg, dnn_mem_t &mem,
17221722
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
17231723
assert(bin_po_idx < attr.post_ops.len());
17241724
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
1725+
const bool is_src2_mem
1726+
= ((exec_arg & DNNL_ARG_SRC_2) == DNNL_ARG_SRC_2);
1727+
17251728
// Binary post-op filling.
1726-
fill_cfg_t def_binary_cfg(mem.dt(), -16.f, 16.f, /* int = */ true,
1727-
alg, "def_binary_post_op");
1728-
const auto it = fill_cfg_map.find(DNNL_ARG_SRC_1);
1729-
const bool has_external_cfg = it != fill_cfg_map.end();
1730-
const fill_cfg_t &binary_fill_cfg
1731-
= has_external_cfg ? (*it).second : def_binary_cfg;
1732-
TIME_FILL(SAFE(fill_random_real(mem, ref_mem, res, binary_fill_cfg),
1733-
WARN));
1729+
if (!is_src2_mem
1730+
|| attr.post_ops.entry[bin_po_idx]
1731+
.is_binary_kind_with_ternary_op()) {
1732+
const int f_min = is_src2_mem ? 0 : -16.f;
1733+
fill_cfg_t def_binary_cfg(mem.dt(), f_min, 16.f,
1734+
/* int = */ true, alg, "def_binary_post_op");
1735+
const auto it = fill_cfg_map.find(
1736+
is_src2_mem ? DNNL_ARG_SRC_2 : DNNL_ARG_SRC_1);
1737+
const bool has_external_cfg = it != fill_cfg_map.end();
1738+
const fill_cfg_t &binary_fill_cfg
1739+
= has_external_cfg ? (*it).second : def_binary_cfg;
1740+
TIME_FILL(SAFE(
1741+
fill_random_real(mem, ref_mem, res, binary_fill_cfg),
1742+
WARN));
1743+
}
17341744
} else if (exec_arg & DNNL_ARG_WEIGHTS) {
17351745
// Prelu post-op filling.
17361746
fill_cfg_t def_prelu_fill_cfg(mem.dt(), -2.f, 2.f, /* int = */ true,

tests/benchdnn/dnnl_common.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -868,9 +868,13 @@ void init_memory_args(dnn_mem_map_t &mem_map, const prb_t *prb,
868868
for (int idx = 0; idx < dnnl_post_ops_len(const_po); ++idx) {
869869
if (dnnl_post_ops_get_kind(const_po, idx) != dnnl_binary) continue;
870870

871-
int po_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
872-
const auto &po_md = query_md(const_pd, po_arg);
873-
mem_map.emplace(po_arg, dnn_mem_t(po_md, test_engine));
871+
int po_arg1 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
872+
const auto &po_md1 = query_md(const_pd, po_arg1);
873+
mem_map.emplace(po_arg1, dnn_mem_t(po_md1, test_engine));
874+
875+
int po_arg2 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2;
876+
const auto &po_md2 = query_md(const_pd, po_arg2);
877+
mem_map.emplace(po_arg2, dnn_mem_t(po_md2, test_engine));
874878
}
875879

876880
// Prelu post-op.

tests/benchdnn/self/common.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

tests/benchdnn/utils/parser.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ attr_t::post_ops_t parse_attr_post_ops_func(const std::string &s) {
217217
} else if (e.is_binary_kind()) {
218218
const auto dt_str = get_substr(subs, subs_pos, ':');
219219
e.binary.src1_dt = str2dt(dt_str.c_str());
220+
221+
if (e.is_binary_kind_with_ternary_op()) e.binary.src2_dt = dnnl_s8;
222+
220223
if (e.binary.src1_dt == dnnl_data_type_undef) {
221224
BENCHDNN_PRINT(0, "%s \'%s\' %s\n",
222225
"Error: binary post-op data type", dt_str.c_str(),

0 commit comments

Comments
 (0)