|
1 | 1 | /******************************************************************************* |
2 | | -* Copyright 2017-2024 Intel Corporation |
| 2 | +* Copyright 2017-2025 Intel Corporation |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
@@ -629,12 +629,12 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const { |
629 | 629 | return kind > ELTWISE_START && kind < ELTWISE_END; |
630 | 630 | } |
631 | 631 | bool attr_t::post_ops_t::entry_t::is_binary_kind() const { |
632 | | - // binary select is a ternary operation and not currently |
633 | | - // supported in post-ops for the binary primitive |
634 | | - // TODO: add post-ops support for binary select operation |
635 | 632 | return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END |
636 | 633 | && kind != pk_t::SELECT; |
637 | 634 | } |
| 635 | +bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op() const { |
| 636 | + return kind == pk_t::SELECT; |
| 637 | +} |
638 | 638 | bool attr_t::post_ops_t::entry_t::is_prelu_kind() const { |
639 | 639 | return kind == PRELU; |
640 | 640 | } |
@@ -1161,9 +1161,17 @@ dnnl_primitive_attr_t create_dnnl_attr( |
1161 | 1161 | } else if (e.is_binary_kind()) { |
1162 | 1162 | const auto &src1_md = attr_args.get_md( |
1163 | 1163 | (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1)); |
| 1164 | + const auto &src2_md = attr_args.get_md( |
| 1165 | + (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2)); |
1164 | 1166 | assert(query_md_ndims(src1_md) != 0); |
| 1167 | + |
| 1168 | + if (e.is_binary_kind_with_ternary_op()) |
| 1169 | + assert(query_md_ndims(src2_md) != 0); |
| 1170 | + |
| 1171 | + // temporarily updating function to test without breaking |
1165 | 1172 | DNN_SAFE_V(dnnl_post_ops_append_binary( |
1166 | | - ops, e.binary.alg, src1_md)); |
| 1173 | + ops, e.binary.alg, src1_md, src2_md)); |
| 1174 | + |
1167 | 1175 | } else if (e.is_prelu_kind()) { |
1168 | 1176 | const auto &policy = e.prelu.policy; |
1169 | 1177 | const auto mask = attr_t::get_default_mask(policy); |
|
0 commit comments