Skip to content

Commit 77885c2

Browse files
committed
cpu causal_conv1d op
1 parent 88a4f85 commit 77885c2

File tree

19 files changed

+647
-4
lines changed

19 files changed

+647
-4
lines changed

src/bindings/python/src/openvino/_pyopenvino/op/__init__.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ class _PagedAttentionExtension(openvino._pyopenvino.Node):
183183
"""
184184
def __init__(self, arg0: collections.abc.Sequence[openvino._pyopenvino.Output]) -> None:
185185
...
186+
187+
class _CausalConv1D(openvino._pyopenvino.Node):
188+
"""
189+
Experimental extension for CausalConv1D operation. Use with care: no backward compatibility is guaranteed in future releases.
190+
"""
191+
def __init__(self, arg0: collections.abc.Sequence[openvino._pyopenvino.Output]) -> None:
192+
...
186193
class assign(openvino._pyopenvino.Node):
187194
"""
188195
openvino.op.assign wraps ov::op::v6::Assign

src/bindings/python/src/openvino/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from openvino._pyopenvino.op import Constant
1212
from openvino._pyopenvino.op import assign
1313
from openvino._pyopenvino.op import _PagedAttentionExtension
14+
from openvino._pyopenvino.op import _CausalConv1D
1415
from openvino._pyopenvino.op import Parameter
1516
from openvino._pyopenvino.op import if_op
1617
from openvino._pyopenvino.op import loop

src/bindings/python/src/openvino/op/__init__.pyi

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from openvino._pyopenvino.op import Constant
55
from openvino._pyopenvino.op import Parameter
66
from openvino._pyopenvino.op import Result
77
from openvino._pyopenvino.op import _PagedAttentionExtension
8+
from openvino._pyopenvino.op import _CausalConv1D
89
from openvino._pyopenvino.op import assign
910
from openvino._pyopenvino.op import if_op
1011
from openvino._pyopenvino.op import loop
@@ -15,4 +16,16 @@ from openvino._pyopenvino.op import tensor_iterator
1516
Package: openvino.op
1617
Low level wrappers for the c++ api in ov::op.
1718
"""
18-
__all__: list[str] = ['Constant', 'Parameter', 'Result', 'assign', 'if_op', 'loop', 'read_value', 'tensor_iterator', 'util']
19+
__all__: list[str] = [
20+
'Constant',
21+
'Parameter',
22+
'Result',
23+
'_PagedAttentionExtension',
24+
'_CausalConv1D',
25+
'assign',
26+
'if_op',
27+
'loop',
28+
'read_value',
29+
'tensor_iterator',
30+
'util',
31+
]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "pyopenvino/graph/ops/causal_conv1d.hpp"
6+
7+
#include "openvino/op/causal_conv1d.hpp"
8+
#include "pyopenvino/core/common.hpp"
9+
10+
namespace py = pybind11;
11+
12+
void regclass_graph_op_CausalConv1D(py::module m) {
13+
using ov::op::CausalConv1D;
14+
py::class_<CausalConv1D, std::shared_ptr<CausalConv1D>, ov::Node> cls(m, "_CausalConv1D");
15+
cls.doc() = "Experimental extension for CausalConv1D operation. Use with care: no backward compatibility is guaranteed in future releases.";
16+
cls.def(py::init<const ov::OutputVector&, const std::string&>(), py::arg("inputs"), py::arg("activation") = "silu");
17+
cls.def("set_activation", &CausalConv1D::set_activation);
18+
cls.def("get_activation", &CausalConv1D::get_activation);
19+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#pragma once
5+
6+
#include <pybind11/pybind11.h>
7+
8+
namespace py = pybind11;
9+
10+
void regclass_graph_op_CausalConv1D(py::module m);

src/bindings/python/src/pyopenvino/pyopenvino.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "pyopenvino/graph/ops/if.hpp"
5858
#include "pyopenvino/graph/ops/loop.hpp"
5959
#include "pyopenvino/graph/ops/paged_attention_extension.hpp"
60+
#include "pyopenvino/graph/ops/causal_conv1d.hpp"
6061
#include "pyopenvino/graph/ops/parameter.hpp"
6162
#include "pyopenvino/graph/ops/read_value.hpp"
6263
#include "pyopenvino/graph/ops/result.hpp"
@@ -242,6 +243,7 @@ PYBIND11_MODULE(_pyopenvino, m) {
242243
regclass_graph_op_Assign(m_op);
243244
regclass_graph_op_Constant(m_op);
244245
regclass_graph_op_PagedAttentionExtension(m_op);
246+
regclass_graph_op_CausalConv1D(m_op);
245247
regclass_graph_op_Parameter(m_op);
246248
regclass_graph_op_ReadValue(m_op);
247249
regclass_graph_op_Result(m_op);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#pragma once
5+
6+
#include "openvino/op/op.hpp"
7+
8+
namespace ov {
9+
namespace op {
10+
11+
// Experimental op mirroring the CPU plugin implementation. Backward compatibility is not guaranteed.
12+
class OPENVINO_API CausalConv1D : public ov::op::Op {
13+
public:
14+
OPENVINO_OP("CausalConv1D");
15+
16+
CausalConv1D() = default;
17+
CausalConv1D(const ov::OutputVector& args, const std::string& activation = "silu");
18+
19+
void validate_and_infer_types() override;
20+
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
21+
bool visit_attributes(ov::AttributeVisitor& visitor) override;
22+
23+
void set_activation(const std::string& activation);
24+
const std::string& get_activation() const;
25+
26+
private:
27+
std::string m_activation = "silu";
28+
};
29+
30+
} // namespace op
31+
} // namespace ov

src/core/src/op/causal_conv1d.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/op/causal_conv1d.hpp"
6+
7+
#include <algorithm>
8+
#include <array>
9+
#include <initializer_list>
10+
11+
#include "dimension_util.hpp"
12+
#include "itt.hpp"
13+
#include "openvino/core/validation_util.hpp"
14+
#include "openvino/core/attribute_visitor.hpp"
15+
16+
namespace {
17+
18+
inline void input_check(const ov::Node* node,
19+
size_t idx,
20+
const std::string_view input_name,
21+
std::initializer_list<ov::Rank>&& allowed_ranks,
22+
const std::vector<ov::element::Type>& allowed_types) {
23+
using namespace ov;
24+
using namespace ov::util;
25+
using namespace ov::element;
26+
27+
const auto& rank = node->get_input_partial_shape(idx).rank();
28+
const auto& tp = node->get_input_element_type(idx);
29+
30+
auto rank_check = [&](const Rank& r) {
31+
return r.is_dynamic() || allowed_ranks.size() == 0 || is_rank_compatible_any_of(r.get_length(), allowed_ranks);
32+
};
33+
34+
auto type_check = [&](const Type& t) {
35+
if (allowed_types.empty()) {
36+
return true;
37+
}
38+
auto it = std::find(allowed_types.begin(), allowed_types.end(), tp);
39+
return t.is_dynamic() || it != allowed_types.end();
40+
};
41+
42+
NODE_VALIDATION_CHECK(node,
43+
rank_check(rank),
44+
"Rank of `",
45+
input_name,
46+
"` input should be in [dynamic, ",
47+
join(allowed_ranks),
48+
"] list, but it is ",
49+
rank,
50+
".");
51+
52+
NODE_VALIDATION_CHECK(node,
53+
type_check(tp),
54+
"Element type of `",
55+
input_name,
56+
"` input should be in [dynamic, ",
57+
join(allowed_types),
58+
"] list, but it is ",
59+
tp,
60+
".");
61+
}
62+
63+
} // namespace
64+
65+
namespace ov {
66+
namespace op {
67+
68+
CausalConv1D::CausalConv1D(const ov::OutputVector& args, const std::string& activation)
69+
: ov::op::Op(args), m_activation(activation) {
70+
constructor_validate_and_infer_types();
71+
}
72+
73+
void CausalConv1D::validate_and_infer_types() {
74+
OV_OP_SCOPE(CausalConv1D_validate_and_infer_types);
75+
76+
NODE_VALIDATION_CHECK(this,
77+
get_input_size() == 3 || get_input_size() == 4,
78+
"CausalConv1D expects 3 or 4 inputs, but it has ",
79+
get_input_size());
80+
81+
// hidden_states: (batch, hidden_size, seq_len)
82+
input_check(this, 0, "hidden_states", {3}, {});
83+
// conv_state: (batch, hidden_size, kernel)
84+
input_check(this, 1, "conv_state", {3}, {});
85+
// weight: (hidden_size, kernel)
86+
input_check(this, 2, "weight", {2}, {});
87+
88+
if (get_input_size() == 4) {
89+
// bias: (hidden_size)
90+
input_check(this, 3, "bias", {1}, {});
91+
}
92+
93+
const auto hidden_ps = get_input_partial_shape(0);
94+
const auto state_ps = get_input_partial_shape(1);
95+
const auto weight_ps = get_input_partial_shape(2);
96+
97+
if (hidden_ps.rank().is_static() && state_ps.rank().is_static()) {
98+
NODE_VALIDATION_CHECK(this,
99+
hidden_ps[1].compatible(state_ps[1]),
100+
"Hidden size of hidden_states (",
101+
hidden_ps[1],
102+
") and conv_state (",
103+
state_ps[1],
104+
") inputs must match.");
105+
}
106+
if (hidden_ps.rank().is_static() && weight_ps.rank().is_static()) {
107+
NODE_VALIDATION_CHECK(this,
108+
hidden_ps[1].compatible(weight_ps[0]),
109+
"Hidden size of hidden_states (",
110+
hidden_ps[1],
111+
") and weight (",
112+
weight_ps[0],
113+
") inputs must match.");
114+
}
115+
if (state_ps.rank().is_static() && weight_ps.rank().is_static()) {
116+
NODE_VALIDATION_CHECK(this,
117+
state_ps[2].compatible(weight_ps[1]),
118+
"Kernel length of conv_state (",
119+
state_ps[2],
120+
") and weight (",
121+
weight_ps[1],
122+
") inputs must match.");
123+
}
124+
if (get_input_size() == 4) {
125+
const auto bias_ps = get_input_partial_shape(3);
126+
if (bias_ps.rank().is_static() && weight_ps.rank().is_static()) {
127+
NODE_VALIDATION_CHECK(this,
128+
bias_ps[0].compatible(weight_ps[0]),
129+
"Bias length (",
130+
bias_ps[0],
131+
") must match hidden size (",
132+
weight_ps[0],
133+
").");
134+
}
135+
}
136+
137+
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
138+
set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
139+
}
140+
141+
std::shared_ptr<ov::Node> CausalConv1D::clone_with_new_inputs(const ov::OutputVector& new_args) const {
142+
return std::make_shared<CausalConv1D>(new_args, m_activation);
143+
}
144+
145+
bool CausalConv1D::visit_attributes(ov::AttributeVisitor& visitor) {
146+
visitor.on_attribute("activation", m_activation);
147+
return true;
148+
}
149+
150+
void CausalConv1D::set_activation(const std::string& activation) {
151+
m_activation = activation;
152+
}
153+
154+
const std::string& CausalConv1D::get_activation() const {
155+
return m_activation;
156+
}
157+
158+
} // namespace op
159+
} // namespace ov

src/plugins/intel_cpu/src/cpu_types.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ static const TypeToNameMap& get_type_to_name_tbl() {
267267
{"SearchSorted", Type::SearchSorted},
268268
{"LoraSubgraph", Type::LoRA},
269269
{"BatchGatherMatmul", Type::GatherMatmul},
270-
{"BatchGatherMatmulCompressed", Type::GatherMatmul}};
270+
{"BatchGatherMatmulCompressed", Type::GatherMatmul},
271+
{"CausalConv1D", Type::CausalConv1D}};
271272
return type_to_name_tbl;
272273
}
273274

@@ -403,6 +404,7 @@ std::string NameFromType(const Type type) {
403404
CASE(SegmentMax);
404405
CASE(LoRA);
405406
CASE(GatherMatmul);
407+
CASE(CausalConv1D);
406408
CASE(Unknown);
407409
}
408410
#undef CASE

src/plugins/intel_cpu/src/cpu_types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ enum class Type : uint8_t {
138138
SearchSorted,
139139
SegmentMax,
140140
LoRA,
141-
GatherMatmul
141+
GatherMatmul,
142+
CausalConv1D
142143
};
143144

144145
enum class Algorithm : uint8_t {

0 commit comments

Comments
 (0)