Skip to content

Commit 1b24530

Browse files
committed
tests: benchdnn: update post-ops calls to support binary select op
1 parent 3379512 commit 1b24530

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

tests/benchdnn/dnn_types.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -629,12 +629,12 @@ bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
629629
return kind > ELTWISE_START && kind < ELTWISE_END;
630630
}
631631
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
632-
// binary select is a ternary operation and not currently
633-
// supported in post-ops for the binary primitive
634-
// TODO: add post-ops support for binary select operation
635632
return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END
636633
&& kind != pk_t::SELECT;
637634
}
635+
bool attr_t::post_ops_t::entry_t::is_binary_kind_with_ternary_op() const {
636+
return kind == pk_t::SELECT;
637+
}
638638
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
639639
return kind == PRELU;
640640
}
@@ -1161,9 +1161,17 @@ dnnl_primitive_attr_t create_dnnl_attr(
11611161
} else if (e.is_binary_kind()) {
11621162
const auto &src1_md = attr_args.get_md(
11631163
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
1164+
const auto &src2_md = attr_args.get_md(
1165+
(DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2));
11641166
assert(query_md_ndims(src1_md) != 0);
1167+
1168+
if (e.is_binary_kind_with_ternary_op())
1169+
assert(query_md_ndims(src2_md) != 0);
1170+
1171+
// temporarily updating function to test without breaking
11651172
DNN_SAFE_V(dnnl_post_ops_append_binary(
1166-
ops, e.binary.alg, src1_md));
1173+
ops, e.binary.alg, src1_md, src2_md));
1174+
11671175
} else if (e.is_prelu_kind()) {
11681176
const auto &policy = e.prelu.policy;
11691177
const auto mask = attr_t::get_default_mask(policy);

tests/benchdnn/dnn_types.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2017-2024 Intel Corporation
2+
* Copyright 2017-2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -336,6 +336,7 @@ struct attr_t {
336336
bool is_convolution_kind() const;
337337
bool is_eltwise_kind() const;
338338
bool is_binary_kind() const;
339+
bool is_binary_kind_with_ternary_op() const;
339340
bool is_prelu_kind() const;
340341
};
341342

0 commit comments

Comments
 (0)