@@ -598,6 +598,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
598598
599599 assert (mask >= 0 );
600600 v_masks.emplace_back (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | arg, mask);
601+
602+ // there is no broadcasting support for the ternary src2 input, hence
603+ // no mask is required.
604+ if (e.is_binary_kind_with_ternary_op ())
605+ v_masks.emplace_back (
606+ DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_2, -1 );
601607 }
602608 return v_masks;
603609}
@@ -630,11 +636,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
630636 return kind > ELTWISE_START && kind < ELTWISE_END;
631637}
632638bool attr_t::post_ops_t::entry_t::is_binary_kind () const {
633- // binary select is a ternary operation and not currently
634- // supported in post-ops for the binary primitive
635- // TODO: add post-ops support for binary select operation
636- return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END
637- && kind != pk_t ::SELECT;
639+ return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END;
640+ }
641+ bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op () const {
642+ return kind == pk_t ::SELECT;
638643}
639644bool attr_t::post_ops_t::entry_t::is_prelu_kind () const {
640645 return kind == PRELU;
@@ -1065,6 +1070,15 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10651070 mds.emplace ((DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx)
10661071 | po_rhs_tensor_entry.arg_attr_mask ),
10671072 std::move (rhs_tensor_desc));
1073+
1074+ if (e.is_binary_kind_with_ternary_op ()) {
1075+ auto rhs_select_tensor_desc = dnn_mem_t::init_md (
1076+ ndims, dims, e.binary .src2_dt , tag::any);
1077+ mds.emplace (
1078+ (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_2),
1079+ std::move (rhs_select_tensor_desc));
1080+ }
1081+
10681082 } else if (e.is_convolution_kind ()) {
10691083 // Update dims for post operations appended after conv_dw
10701084 conv_dw_fusion::get_fused_conv_dst_dims (ndims, e, dims, dims);
@@ -1168,9 +1182,17 @@ dnnl_primitive_attr_t create_dnnl_attr(
11681182 } else if (e.is_binary_kind ()) {
11691183 const auto &src1_md = attr_args.get_md (
11701184 (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_1));
1185+ const auto &src2_md = attr_args.get_md (
1186+ (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_2));
11711187 assert (query_md_ndims (src1_md) != 0 );
1172- DNN_SAFE_V (dnnl_post_ops_append_binary (
1173- ops, e.binary .alg , src1_md));
1188+
1189+ if (e.is_binary_kind_with_ternary_op ()) {
1190+ assert (query_md_ndims (src2_md) != 0 );
1191+ }
1192+
1193+ DNN_SAFE_V (dnnl_post_ops_append_binary_v2 (
1194+ ops, e.binary .alg , src1_md, src2_md));
1195+
11741196 } else if (e.is_prelu_kind ()) {
11751197 const auto &policy = e.prelu .policy ;
11761198 const auto mask = attr_t::get_default_mask (policy);
@@ -1678,7 +1700,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16781700 const auto &b = e.eltwise .beta ;
16791701 val = compute_eltwise_fwd (e.kind , val, a, b);
16801702 } else if (e.is_binary_kind ()) {
1681- val = compute_binary (e.kind , val, *it_po, false );
1703+
1704+ auto src1_val = *it_po;
1705+ bool src2_val = false ;
1706+
1707+ if (e.is_binary_kind_with_ternary_op ()) {
1708+ it_po++;
1709+ src2_val = static_cast <bool >(*it_po);
1710+ }
1711+ val = compute_binary (e.kind , val, src1_val, src2_val);
16821712 it_po++;
16831713 } else if (e.is_prelu_kind ()) {
16841714 val = val > 0 ? val : val * (*it_po);
0 commit comments