11/* ******************************************************************************
2- * Copyright 2017-2024 Intel Corporation
2+ * Copyright 2017-2025 Intel Corporation
33*
44* Licensed under the Apache License, Version 2.0 (the "License");
55* you may not use this file except in compliance with the License.
@@ -582,6 +582,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
582582 const auto &e = this ->entry [idx];
583583 int mask = -1 ;
584584 int arg = DNNL_ARG_UNDEF;
585+
586+ int mask2 = -1 ;
587+ int arg2 = DNNL_ARG_UNDEF;
588+
585589 if (e.is_binary_kind ()) {
586590 using mask_input_t = entry_t ::binary_t ::mask_input_t ;
587591 auto mask_input = e.binary .mask_input ;
@@ -590,6 +594,12 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
590594 : policy2mask (
591595 DNNL_ARG_SRC_1, e.binary .policy , prim_kind, ndims);
592596 arg = DNNL_ARG_SRC_1;
597+
598+ if (e.is_binary_kind_with_ternary_op ()) {
599+ mask2 = e.binary .mask ;
600+ arg2 = DNNL_ARG_SRC_2;
601+ }
602+
593603 } else if (e.is_prelu_kind ()) {
594604 mask = attr_t::get_default_mask (e.prelu .policy );
595605 arg = DNNL_ARG_WEIGHTS;
@@ -599,6 +609,10 @@ std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks(
599609 assert (mask >= 0 );
600610 v_masks.emplace_back (std::make_pair (
601611 DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | arg, mask));
612+
613+ if (e.is_binary_kind_with_ternary_op ())
614+ v_masks.emplace_back (std::make_pair (
615+ DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | arg2, mask2));
602616 }
603617 return v_masks;
604618}
@@ -630,11 +644,10 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
630644 return kind > ELTWISE_START && kind < ELTWISE_END;
631645}
632646bool attr_t::post_ops_t::entry_t::is_binary_kind () const {
633- // binary select is a ternary operation and not currently
634- // supported in post-ops for the binary primitive
635- // TODO: add post-ops support for binary select operation
636- return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END
637- && kind != pk_t ::SELECT;
647+ return kind > pk_t ::BINARY_START && kind < pk_t ::BINARY_END;
648+ }
649+ bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op () const {
650+ return kind == pk_t ::SELECT;
638651}
639652bool attr_t::post_ops_t::entry_t::is_prelu_kind () const {
640653 return kind == PRELU;
@@ -1059,6 +1072,22 @@ int attr_args_t::prepare_post_ops_mds(const attr_t &attr, int ndims,
10591072 mds.emplace ((DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx)
10601073 | po_rhs_tensor_entry.arg_attr_mask ),
10611074 std::move (rhs_tensor_desc));
1075+
1076+ if (e.is_binary_kind_with_ternary_op ()) {
1077+
1078+ const post_ops_rhs_tensor_entry_t
1079+ post_op_rhs_select_tensor_entry
1080+ = {e.binary .src2_dt , 0 , e.binary .tag , DNNL_ARG_SRC_2};
1081+
1082+ auto rhs_select_tensor_desc = dnn_mem_t::init_md (ndims, dims,
1083+ e.binary .src2_dt , post_op_rhs_select_tensor_entry.tag );
1084+
1085+ mds.emplace ((DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx)
1086+ | post_op_rhs_select_tensor_entry
1087+ .arg_attr_mask ),
1088+ std::move (rhs_select_tensor_desc));
1089+ }
1090+
10621091 } else if (e.is_convolution_kind ()) {
10631092 // Update dims for post operations appended after conv_dw
10641093 conv_dw_fusion::get_fused_conv_dst_dims (ndims, e, dims, dims);
@@ -1162,9 +1191,16 @@ dnnl_primitive_attr_t create_dnnl_attr(
11621191 } else if (e.is_binary_kind ()) {
11631192 const auto &src1_md = attr_args.get_md (
11641193 (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_1));
1194+ const auto &src2_md = attr_args.get_md (
1195+ (DNNL_ARG_ATTR_MULTIPLE_POST_OP (idx) | DNNL_ARG_SRC_2));
11651196 assert (query_md_ndims (src1_md) != 0 );
1197+
1198+ if (e.is_binary_kind_with_ternary_op ())
1199+ assert (query_md_ndims (src2_md) != 0 );
1200+
11661201 DNN_SAFE_V (dnnl_post_ops_append_binary (
1167- ops, e.binary .alg , src1_md));
1202+ ops, e.binary .alg , src1_md, src2_md));
1203+
11681204 } else if (e.is_prelu_kind ()) {
11691205 const auto &policy = e.prelu .policy ;
11701206 const auto mask = attr_t::get_default_mask (policy);
@@ -1585,6 +1621,9 @@ float compute_eltwise_bwd(
15851621
15861622float compute_binary (pk_t kind, float src0, float src1, bool src2) {
15871623 // don't compute on nan, propagate it
1624+
1625+ printf (" src0: %f,\t src1: %f,\t src2: %d\n " , src0, src1, src2);
1626+
15881627 if (std::isnan (src0) || std::isnan (src1)) return NAN;
15891628
15901629 if (kind == pk_t ::ADD) {
@@ -1672,7 +1711,15 @@ void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
16721711 const auto &b = e.eltwise .beta ;
16731712 val = compute_eltwise_fwd (e.kind , val, a, b);
16741713 } else if (e.is_binary_kind ()) {
1675- val = compute_binary (e.kind , val, *it_po, false );
1714+
1715+ auto src1_val = *it_po;
1716+ bool src2_val = false ;
1717+
1718+ if (e.is_binary_kind_with_ternary_op ()) {
1719+ it_po++;
1720+ src2_val = *it_po;
1721+ }
1722+ val = compute_binary (e.kind , val, src1_val, src2_val);
16761723 it_po++;
16771724 } else if (e.is_prelu_kind ()) {
16781725 val = val > 0 ? val : val * (*it_po);
0 commit comments