Skip to content

Commit d36e9a8

Browse files
committed
tests: benchdnn: update postops to support binary select op
1 parent 1713917 commit d36e9a8

File tree

13 files changed

+133
-30
lines changed

13 files changed

+133
-30
lines changed

tests/benchdnn/dnn_types.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
598598

599599
assert(mask >= 0);
600600
v_masks.emplace_back(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg, mask);
601+
602+
// there is no broadcasting support for the ternary src2 input, hence
603+
// no mask is required.
604+
if (e.is_binary_kind_with_ternary_op())
605+
v_masks.emplace_back(
606+
DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2, -1);
601607
}
602608
return v_masks;
603609
}
@@ -630,11 +636,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
630636
return kind > ELTWISE_START && kind < ELTWISE_END;
631637
}
632638
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
633-
// binary select is a ternary operation and not currently
634-
// supported in post-ops for the binary primitive
635-
// TODO: add post-ops support for binary select operation
636-
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END
637-
&& kind != pk_t::SELECT;
639+
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END;
640+
}
641+
bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op() const {
642+
return kind == pk_t::SELECT;
638643
}
639644
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
640645
return kind == PRELU;
@@ -1065,6 +1070,15 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10651070
mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
10661071
| po_rhs_tensor_entry.arg_attr_mask),
10671072
std::move(rhs_tensor_desc));
1073+
1074+
if (e.is_binary_kind_with_ternary_op()) {
1075+
auto rhs_select_tensor_desc = dnn_mem_t::init_md(
1076+
ndims, dims, e.binary.src2_dt, tag::any);
1077+
mds.emplace(
1078+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2),
1079+
std::move(rhs_select_tensor_desc));
1080+
}
1081+
10681082
} else if (e.is_convolution_kind()) {
10691083
// Update dims for post operations appended after conv_dw
10701084
conv_dw_fusion::get_fused_conv_dst_dims(ndims, e, dims, dims);
@@ -1168,9 +1182,17 @@ dnnl_primitive_attr_t create_dnnl_attr(
11681182
} else if (e.is_binary_kind()) {
11691183
const auto &src1_md = attr_args.get_md(
11701184
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
1185+
const auto &src2_md = attr_args.get_md(
1186+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2));
11711187
assert(query_md_ndims(src1_md) != 0);
1172-
DNN_SAFE_V(dnnl_post_ops_append_binary(
1173-
ops, e.binary.alg, src1_md));
1188+
1189+
if (e.is_binary_kind_with_ternary_op()) {
1190+
assert(query_md_ndims(src2_md) != 0);
1191+
}
1192+
1193+
DNN_SAFE_V(dnnl_post_ops_append_binary_v2(
1194+
ops, e.binary.alg, src1_md, src2_md));
1195+
11741196
} else if (e.is_prelu_kind()) {
11751197
const auto &policy = e.prelu.policy;
11761198
const auto mask = attr_t::get_default_mask(policy);
@@ -1678,7 +1700,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16781700
const auto &b = e.eltwise.beta;
16791701
val = compute_eltwise_fwd(e.kind, val, a, b);
16801702
} else if (e.is_binary_kind()) {
1681-
val = compute_binary(e.kind, val, *it_po, false);
1703+
1704+
auto src1_val = *it_po;
1705+
bool src2_val = false;
1706+
1707+
if (e.is_binary_kind_with_ternary_op()) {
1708+
it_po++;
1709+
src2_val = static_cast<bool>(*it_po);
1710+
}
1711+
val = compute_binary(e.kind, val, src1_val, src2_val);
16821712
it_po++;
16831713
} else if (e.is_prelu_kind()) {
16841714
val = val > 0 ? val : val * (*it_po);

tests/benchdnn/dnn_types.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ struct attr_t {
320320

321321
dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
322322
dnnl_data_type_t src1_dt = dnnl_data_type_undef;
323+
dnnl_data_type_t src2_dt = dnnl_data_type_undef;
323324
policy_t policy = policy_t::COMMON;
324325
int64_t mask = -1;
325326
mask_input_t mask_input = mask_input_t::none;
@@ -333,6 +334,7 @@ struct attr_t {
333334
bool is_convolution_kind() const;
334335
bool is_eltwise_kind() const;
335336
bool is_binary_kind() const;
337+
bool is_binary_kind_with_ternary_op() const;
336338
bool is_prelu_kind() const;
337339
};
338340

tests/benchdnn/dnnl_common.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,20 +1801,42 @@ int init_ref_memory_args_default_case(int exec_arg, dnn_mem_t &mem,
18011801
const bool is_rounding_seed = (exec_arg == DNNL_ARG_ATTR_ROUNDING_SEED);
18021802

18031803
if (is_post_ops_arg) {
1804-
if (exec_arg & DNNL_ARG_SRC_1) {
1805-
const int bin_po_idx
1806-
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
1807-
assert(bin_po_idx < attr.post_ops.len());
1804+
const int bin_po_idx
1805+
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
1806+
assert(bin_po_idx < attr.post_ops.len());
1807+
const bool exact_match_for_src1_arg = !(exec_arg
1808+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
1809+
| DNNL_ARG_SRC_1));
1810+
const bool exact_match_for_src2_arg
1811+
= !(exec_arg
1812+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
1813+
| DNNL_ARG_SRC_2))
1814+
&& attr.post_ops.entry[bin_po_idx]
1815+
.is_binary_kind_with_ternary_op();
1816+
1817+
if (exact_match_for_src1_arg) {
18081818
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
1809-
// Binary post-op filling.
1810-
fill_cfg_t def_binary_cfg(mem.dt(), -16.f, 16.f, /* int = */ true,
1811-
alg, "def_binary_post_op");
1819+
// Binary post-op filling for src1 tensor
1820+
fill_cfg_t def_binary_cfg(mem.dt(), -16.f, 16.f,
1821+
/* int = */ true, alg, "def_binary_post_op_src1");
18121822
const auto it = fill_cfg_map.find(DNNL_ARG_SRC_1);
18131823
const bool has_external_cfg = it != fill_cfg_map.end();
18141824
const fill_cfg_t &binary_fill_cfg
18151825
= has_external_cfg ? (*it).second : def_binary_cfg;
18161826
TIME_FILL(SAFE(fill_random_real(mem, ref_mem, res, binary_fill_cfg),
18171827
WARN));
1828+
} else if (exact_match_for_src2_arg) {
1829+
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
1830+
// Binary post-op filling for src2 conditional tensor
1831+
// - ignored if the algorithm does not take ternary inputs
1832+
fill_cfg_t def_binary_cfg(mem.dt(), 0, 16.f,
1833+
/* int = */ true, alg, "def_binary_post_op_src2");
1834+
const auto it = fill_cfg_map.find(DNNL_ARG_SRC_2);
1835+
const bool has_external_cfg = it != fill_cfg_map.end();
1836+
const fill_cfg_t &binary_fill_cfg
1837+
= has_external_cfg ? (*it).second : def_binary_cfg;
1838+
TIME_FILL(SAFE(fill_random_real(mem, ref_mem, res, binary_fill_cfg),
1839+
WARN));
18181840
} else if (exec_arg & DNNL_ARG_WEIGHTS) {
18191841
// Prelu post-op filling.
18201842
fill_cfg_t def_prelu_fill_cfg(mem.dt(), -2.f, 2.f, /* int = */ true,

tests/benchdnn/dnnl_common.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -894,9 +894,14 @@ void init_memory_args(dnn_mem_map_t &mem_map, const prb_t *prb,
894894
for (int idx = 0; idx < dnnl_post_ops_len(const_po); ++idx) {
895895
if (dnnl_post_ops_get_kind(const_po, idx) != dnnl_binary) continue;
896896

897-
int po_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
898-
const auto &po_md = query_md(const_pd, po_arg);
899-
mem_map.emplace(po_arg, dnn_mem_t(po_md, test_engine));
897+
int po_arg1 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
898+
const auto &po_md1 = query_md(const_pd, po_arg1);
899+
mem_map.emplace(po_arg1, dnn_mem_t(po_md1, test_engine));
900+
901+
if (!query_po_alg_kind(const_po, idx, dnnl_binary_select)) continue;
902+
int po_arg2 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2;
903+
const auto &po_md2 = query_md(const_pd, po_arg2);
904+
mem_map.emplace(po_arg2, dnn_mem_t(po_md2, test_engine));
900905
}
901906

902907
// Prelu post-op.

tests/benchdnn/gnorm/gnorm.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,13 @@ fill_cfg_t binary_po_fill_cfg(
608608
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
609609
assert(bin_po_idx < attr.post_ops.len());
610610
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
611-
cfg = fill_cfg_t(mem.dt(), 4.f, 16.f, /* int = */ true, alg,
612-
"gnorm_binary_post_op");
611+
const bool is_src1_arg = !(exec_arg
612+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
613+
| DNNL_ARG_SRC_1));
614+
615+
if (is_src1_arg)
616+
cfg = fill_cfg_t(mem.dt(), 4.f, 16.f, /* int = */ true, alg,
617+
"gnorm_binary_post_op");
613618
}
614619
return cfg;
615620
}

tests/benchdnn/lnorm/lnorm.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,13 @@ fill_cfg_t binary_po_fill_cfg(
565565
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
566566
assert(bin_po_idx < attr.post_ops.len());
567567
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
568-
cfg = fill_cfg_t(mem.dt(), 4.f, 16.f, /* int = */ true, alg,
569-
"lnorm_binary_post_op");
568+
const bool is_src1_arg = !(exec_arg
569+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
570+
| DNNL_ARG_SRC_1));
571+
572+
if (is_src1_arg)
573+
cfg = fill_cfg_t(mem.dt(), 4.f, 16.f, /* int = */ true, alg,
574+
"lnorm_binary_post_op");
570575
}
571576
return cfg;
572577
}

tests/benchdnn/pool/pool.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,14 @@ fill_cfg_t binary_po_fill_cfg(
253253
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
254254
assert(bin_po_idx < attr.post_ops.len());
255255
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
256-
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
257-
"pooling_binary_post_op");
256+
257+
const bool is_src1_arg = !(exec_arg
258+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
259+
| DNNL_ARG_SRC_1));
260+
261+
if (is_src1_arg)
262+
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
263+
"pooling_binary_post_op");
258264
}
259265
return cfg;
260266
}

tests/benchdnn/reduction/reduction.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,13 @@ fill_cfg_t binary_po_fill_cfg(
255255
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
256256
assert(bin_po_idx < attr.post_ops.len());
257257
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
258-
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
259-
"reduction_binary_post_op");
258+
const bool is_src1_arg = !(exec_arg
259+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
260+
| DNNL_ARG_SRC_1));
261+
262+
if (is_src1_arg)
263+
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
264+
"reduction_binary_post_op");
260265
}
261266
return cfg;
262267
}

tests/benchdnn/resampling/resampling.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,13 @@ fill_cfg_t binary_po_fill_cfg(
164164
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
165165
assert(bin_po_idx < attr.post_ops.len());
166166
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
167-
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
168-
"resampling_binary_post_op");
167+
const bool is_src1_arg = !(exec_arg
168+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
169+
| DNNL_ARG_SRC_1));
170+
171+
if (is_src1_arg)
172+
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
173+
"resampling_binary_post_op");
169174
}
170175
return cfg;
171176
}

tests/benchdnn/softmax/softmax.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,13 @@ fill_cfg_t binary_po_fill_cfg(
315315
= exec_arg / DNNL_ARG_ATTR_MULTIPLE_POST_OP_BASE - 1;
316316
assert(bin_po_idx < attr.post_ops.len());
317317
const auto alg = attr.post_ops.entry[bin_po_idx].kind;
318-
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
319-
"softmax_binary_post_op");
318+
const bool is_src1_arg = !(exec_arg
319+
^ (DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_po_idx)
320+
| DNNL_ARG_SRC_1));
321+
322+
if (is_src1_arg)
323+
cfg = fill_cfg_t(mem.dt(), 0.f, 16.f, /* int = */ true, alg,
324+
"softmax_binary_post_op");
320325
}
321326
return cfg;
322327
}

0 commit comments

Comments
 (0)