@@ -582,6 +582,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
582582 const auto &e = this ->entry [idx];
583583 int mask = -1 ;
584584 int arg = DNNL_ARG_UNDEF;
585+
586+ int mask2 = -1 ;
587+ int arg2 = DNNL_ARG_UNDEF;
588+
585589 if (e.is_binary_kind ()) {
586590 using mask_input_t = entry_t ::binary_t ::mask_input_t ;
587591 auto mask_input = e.binary .mask_input ;
@@ -590,6 +594,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
590594 : policy2mask (
591595 DNNL_ARG_SRC_1, e.binary .policy , prim_kind, ndims);
592596 arg = DNNL_ARG_SRC_1;
597+
598+ if (e.is_binary_kind_with_ternary_op ()) {
599+ mask2 = e.binary .mask ;
600+ arg2 = DNNL_ARG_SRC_2;
601+ }
602+
593603 } else if (e.is_prelu_kind ()) {
594604 mask = attr_t::get_default_mask (e.prelu .policy );
595605 arg = DNNL_ARG_WEIGHTS;
@@ -599,6 +609,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
599609 assert (mask >= 0 );
600610 v_masks.emplace_back (std::make_pair (
601611 DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | arg, mask));
612+
613+ if (e.is_binary_kind_with_ternary_op ())
614+ v_masks.emplace_back (std::make_pair (
615+ DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | arg2, mask2));
602616 }
603617 return v_masks;
604618}
@@ -631,11 +645,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
631645 return kind > ELTWISE_START && kind < ELTWISE_END;
632646}
633647bool attr_t::post_ops_t::entry_t::is_binary_kind () const {
634- // binary select is a ternary operation and not currently
635- // supported in post-ops for the binary primitive
636- // TODO: add post-ops support for binary select operation
637- return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END
638- && kind != pk_t ::SELECT;
648+ return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END;
649+ }
650+ bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op () const {
651+ return kind == pk_t ::SELECT;
639652}
640653bool attr_t::post_ops_t::entry_t::is_prelu_kind () const {
641654 return kind == PRELU;
@@ -1062,6 +1075,22 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10621075 mds.emplace ((DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx)
10631076 | po_rhs_tensor_entry.arg_attr_mask ),
10641077 std::move (rhs_tensor_desc));
1078+
1079+ if (e.is_binary_kind_with_ternary_op ()) {
1080+
1081+ const post_ops_rhs_tensor_entry_t
1082+ post_op_rhs_select_tensor_entry
1083+ = {e.binary .src2_dt , 0 , e.binary .tag , DNNL_ARG_SRC_2};
1084+
1085+ auto rhs_select_tensor_desc = dnn_mem_t::init_md (ndims, dims,
1086+ e.binary .src2_dt , post_op_rhs_select_tensor_entry.tag );
1087+
1088+ mds.emplace ((DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx)
1089+ | post_op_rhs_select_tensor_entry
1090+ .arg_attr_mask ),
1091+ std::move (rhs_select_tensor_desc));
1092+ }
1093+
10651094 } else if (e.is_convolution_kind ()) {
10661095 // Update dims for post operations appended after conv_dw
10671096 conv_dw_fusion::get_fused_conv_dst_dims (ndims, e, dims, dims);
@@ -1165,9 +1194,16 @@ dnnl_primitive_attr_t create_dnnl_attr(
11651194 } else if (e.is_binary_kind ()) {
11661195 const auto &src1_md = attr_args.get_md (
11671196 (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_1));
1197+ const auto &src2_md = attr_args.get_md (
1198+ (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_2));
11681199 assert (query_md_ndims (src1_md) != 0 );
1169- DNN_SAFE_V (dnnl_post_ops_append_binary (
1170- ops, e.binary .alg , src1_md));
1200+
1201+ if (e.is_binary_kind_with_ternary_op ())
1202+ assert (query_md_ndims (src2_md) != 0 );
1203+
1204+ DNN_SAFE_V (dnnl_post_ops_append_binary_v2 (
1205+ ops, e.binary .alg , src1_md, src2_md));
1206+
11711207 } else if (e.is_prelu_kind ()) {
11721208 const auto &policy = e.prelu .policy ;
11731209 const auto mask = attr_t::get_default_mask (policy);
@@ -1675,7 +1711,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16751711 const auto &b = e.eltwise .beta ;
16761712 val = compute_eltwise_fwd (e.kind , val, a, b);
16771713 } else if (e.is_binary_kind ()) {
1678- val = compute_binary (e.kind , val, *it_po, false );
1714+
1715+ auto src1_val = *it_po;
1716+ bool src2_val = false ;
1717+
1718+ if (e.is_binary_kind_with_ternary_op ()) {
1719+ it_po++;
1720+ src2_val = *it_po;
1721+ }
1722+ val = compute_binary (e.kind , val, src1_val, src2_val);
16791723 it_po++;
16801724 } else if (e.is_prelu_kind ()) {
16811725 val = val > 0 ? val : val * (*it_po);
0 commit comments