Skip to content

Commit 726fdaf

Browse files
esmirnov and dmatveev
authored
[NPUW] gemma-2 patterns added to preserve tail constants matcher (#32465)
### Details: - gemma2-sym works fine after weights are preserved as constants. Performance matches expectations <img width="370" height="640" alt="image" src="https://github.com/user-attachments/assets/8b2c7f39-9c32-4db8-af0f-cbe6429366fd" /> - gemma2-asym works, but accuracy issues were found. Performance matches expectations too <img width="402" height="721" alt="image" src="https://github.com/user-attachments/assets/464d02b2-fd7a-4d85-a1fb-72f77a944e12" /> - as of compiler version 7.28, no accuracy issues are observed anymore - this patch's behavior is, by default, restricted to that compiler version (and newer) only. ### Tickets: - E-189635 --------- Co-authored-by: Dmitry Matveev <dmitry.matveev@intel.com>
1 parent 7e5c0d9 commit 726fdaf

File tree

13 files changed

+137
-43
lines changed

13 files changed

+137
-43
lines changed

src/plugins/intel_npu/src/al/include/intel_npu/config/config.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ TYPE_PRINTER(std::size_t)
5757
#ifndef ONEAPI_MAKE_VERSION
5858
/// @brief Generates generic 'oneAPI' API versions
5959
# define ONEAPI_MAKE_VERSION(_major, _minor) ((_major << 16) | (_minor & 0x0000ffff))
60+
61+
/// @brief extract 'oneAPI' API major version
62+
# define ONEAPI_VERSION_MAJOR(_version) ((_version) >> 16)
63+
64+
/// @brief extract 'oneAPI' API minor version
65+
# define ONEAPI_VERSION_MINOR(_version) ((_version) & 0x0000ffff)
66+
6067
#endif // ONEAPI_MAKE_VERSION
6168

6269
//

src/plugins/intel_npu/src/al/include/intel_npu/config/npuw.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, RunTime);
112112
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, RunTime);
113113
DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, RunTime);
114114
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, RunTime);
115+
DEFINE_OPT(NPUW_MM_GATED, bool, true, npuw::partitioning::matmul_gate_preserve_constants, RunTime);
115116
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, RunTime);
116117
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, RunTime);
117118
DEFINE_OPT(NPUW_SPATIAL, bool, false, npuw::partitioning::spatial, RunTime);

src/plugins/intel_npu/src/al/include/intel_npu/npuw_private_properties.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,15 @@ static constexpr ov::Property<bool> dyn_quant_full{"NPUW_DQ_FULL"};
216216
*/
217217
static constexpr ov::Property<std::string> par_matmul_merge_dims{"NPUW_PMM"};
218218

219+
/**
220+
* @brief
221+
* Type: bool.
222+
* whether to preserve constants for gated version of matmul
223+
* on some version of compiler - might produce incorrect results when enabled
224+
* Default value: YES
225+
*/
226+
static constexpr ov::Property<bool> matmul_gate_preserve_constants{"NPUW_MM_GATED"};
227+
219228
/**
220229
* @brief
221230
* Type: bool.

src/plugins/intel_npu/src/al/src/config/npuw.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
2929
desc.add<NPUW_DQ>();
3030
desc.add<NPUW_DQ_FULL>();
3131
desc.add<NPUW_PMM>();
32+
desc.add<NPUW_MM_GATED>();
3233
desc.add<NPUW_SLICE_OUT>();
3334
desc.add<NPUW_SPATIAL>();
3435
desc.add<NPUW_SPATIAL_NWAY>();

src/plugins/intel_npu/src/plugin/include/properties.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ class Properties final {
172172
ov::intel_npu::npuw::partitioning::dyn_quant.name(),
173173
ov::intel_npu::npuw::partitioning::dyn_quant_full.name(),
174174
ov::intel_npu::npuw::partitioning::par_matmul_merge_dims.name(),
175+
ov::intel_npu::npuw::partitioning::matmul_gate_preserve_constants.name(),
175176
ov::intel_npu::npuw::partitioning::slice_out.name(),
176177
ov::intel_npu::npuw::partitioning::spatial.name(),
177178
ov::intel_npu::npuw::partitioning::spatial_nway.name(),

src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,8 +635,10 @@ bool ov::npuw::CompiledModel::should_use_quantized_host_gather(const std::shared
635635
std::vector<CPtr> to_keep;
636636

637637
ov::pass::GraphRewrite rewr2;
638-
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(to_keep));
639-
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulSymm>(std::ref(to_keep));
638+
ctx.mm_gate = m_cfg.get<::intel_npu::NPUW_MM_GATED>();
639+
640+
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(ctx), std::ref(to_keep));
641+
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulFP8>(std::ref(ctx), std::ref(to_keep));
640642
rewr2.run_on_model(model);
641643
// FIXME: since 3-model pipeline is the default option, the tail will be separate,
642644
// so we need to match either head or tail pattern here for host gather quantized feature to work.
@@ -2506,6 +2508,7 @@ void ov::npuw::CompiledModel::implement_properties() {
25062508
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
25072509
BIND(npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL),
25082510
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
2511+
BIND(npuw::partitioning::matmul_gate_preserve_constants, NPUW_MM_GATED),
25092512
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
25102513
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
25112514
BIND(npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY),

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,23 @@
4343

4444
namespace opp = ov::pass::pattern;
4545

46+
// specific function that match subgraph appeared as result of lpt transformations
47+
auto match_down_up_convert_subgraph_after_lpt = [](const ov::Output<ov::Node>& input) {
48+
auto upconvert = opp::wrap_type<ov::op::v0::Convert>({input}, opp::type_matches(ov::element::f32));
49+
50+
auto upscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0));
51+
auto upmul = opp::wrap_type<ov::op::v1::Multiply>({upconvert, upscale});
52+
53+
auto downscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0));
54+
auto downmul = opp::wrap_type<ov::op::v1::Multiply>({upmul, downscale});
55+
56+
auto downconvert =
57+
opp::wrap_type<ov::op::v0::Convert>({downmul},
58+
opp::type_matches_any({ov::element::f8e4m3, ov::element::f8e5m2}));
59+
60+
return downconvert;
61+
};
62+
4663
class RemoveEmptyKVTensors : public ov::pass::MatcherPass {
4764
public:
4865
OPENVINO_MATCHER_PASS_RTTI("npuw::LLMCompiledModel::RemoveEmptyKVTensors");
@@ -54,7 +71,10 @@ class RemoveEmptyKVTensors : public ov::pass::MatcherPass {
5471

5572
RemoveEmptyKVTensors(Context::Ref ctx) {
5673
auto param = opp::wrap_type<ov::op::v0::Parameter>();
57-
auto concat = opp::wrap_type<ov::op::v0::Concat>({param, opp::any_input()});
74+
auto param_or =
75+
std::make_shared<opp::op::Or>(ov::OutputVector{param, match_down_up_convert_subgraph_after_lpt(param)});
76+
77+
auto concat = opp::wrap_type<ov::op::v0::Concat>({param_or, opp::any_input()});
5878

5979
auto callback = [=](opp::Matcher& m) {
6080
auto& node_to_output = m.get_pattern_value_map();
@@ -63,15 +83,27 @@ class RemoveEmptyKVTensors : public ov::pass::MatcherPass {
6383

6484
ctx.get().old_params.push_back(matched_param);
6585

66-
auto users = matched_param->get_users();
67-
if (users.size() == 2u) {
68-
auto shapeof_node = ov::is_type<ov::op::v3::ShapeOf>(users[0]) ? users[0] : users[1];
69-
NPUW_ASSERT(ov::is_type<ov::op::v3::ShapeOf>(shapeof_node));
70-
auto cst_node =
71-
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, matched_param->get_shape());
72-
ov::replace_node(shapeof_node, cst_node);
73-
} else {
74-
NPUW_ASSERT(users.size() == 1u);
86+
// Use concat's first input source node to find ShapeOf users.
87+
// This works universally for both plain parameter and down_up_convert subgraph cases,
88+
// because in the subgraph case matched_param->get_users() would return the Convert
89+
// node (first node of the subgraph), not the ShapeOf.
90+
auto concat_input0_node = matched_node_concat->input(0).get_source_output().get_node_shared_ptr();
91+
auto users = concat_input0_node->get_users();
92+
93+
// In subgraph case the parameter itself may also have a ShapeOf user,
94+
// so check both the concat input node and the parameter.
95+
if (concat_input0_node != matched_param) {
96+
auto param_users = matched_param->get_users();
97+
users.insert(users.end(), param_users.begin(), param_users.end());
98+
}
99+
100+
// Find and replace ShapeOf nodes with constants
101+
for (auto& user : users) {
102+
if (ov::is_type<ov::op::v3::ShapeOf>(user)) {
103+
auto cst_node =
104+
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, matched_param->get_shape());
105+
ov::replace_node(user, cst_node);
106+
}
75107
}
76108

77109
// Redirect second concat input to every node which reads from concat
@@ -323,22 +355,6 @@ class GroupQueryAttentionDecomposition : public ov::pass::MatcherPass {
323355
class RedirectNewKvToOutput : public ov::pass::MatcherPass {
324356
public:
325357
RedirectNewKvToOutput() {
326-
auto match_down_up_convert_subgraph = [](const ov::Output<ov::Node>& input) {
327-
auto upconvert = opp::wrap_type<ov::op::v0::Convert>({input}, opp::type_matches(ov::element::f32));
328-
329-
auto upscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0));
330-
auto upmul = opp::wrap_type<ov::op::v1::Multiply>({upconvert, upscale});
331-
332-
auto downscale = opp::wrap_type<ov::op::v0::Constant>(opp::rank_equals(0));
333-
auto downmul = opp::wrap_type<ov::op::v1::Multiply>({upmul, downscale});
334-
335-
auto downconvert =
336-
opp::wrap_type<ov::op::v0::Convert>({downmul},
337-
opp::type_matches_any({ov::element::f8e4m3, ov::element::f8e5m2}));
338-
339-
return downconvert;
340-
};
341-
342358
// example of fp8 inputs to concat
343359
// input0 : float8e4m3[1,32,1151,96]
344360
// input1 : float8e4m3[1,32,1,96]
@@ -348,13 +364,13 @@ class RedirectNewKvToOutput : public ov::pass::MatcherPass {
348364
// TODO: this matcher logic better to cover with unit-tests
349365
auto input0 = opp::wrap_type<ov::op::v0::Parameter>();
350366
auto input0_or =
351-
std::make_shared<opp::op::Or>(ov::OutputVector{input0, match_down_up_convert_subgraph(input0)});
367+
std::make_shared<opp::op::Or>(ov::OutputVector{input0, match_down_up_convert_subgraph_after_lpt(input0)});
352368

353369
auto input1 = opp::any_input();
354370

355371
auto kv_concat = opp::wrap_type<ov::op::v0::Concat>({input0_or, input1});
356372
auto result1 = opp::wrap_type<ov::op::v0::Result>(kv_concat);
357-
auto result2 = opp::wrap_type<ov::op::v0::Result>(match_down_up_convert_subgraph(kv_concat));
373+
auto result2 = opp::wrap_type<ov::op::v0::Result>(match_down_up_convert_subgraph_after_lpt(kv_concat));
358374

359375
auto result_or = std::make_shared<opp::op::Or>(ov::OutputVector{result1, result2});
360376

@@ -1175,6 +1191,7 @@ struct NPUDesc {
11751191
std::string arch;
11761192
int64_t max_tiles = 0;
11771193
bool compiler_dq = false;
1194+
bool compiler_matmul_gate = false;
11781195
int64_t compiler_ver = 0;
11791196
bool support_flash_attention_tile = false;
11801197
};
@@ -1212,6 +1229,19 @@ std::optional<NPUDesc> extract_npu_descriptor(const std::shared_ptr<const ov::IP
12121229
ov::AnyMap{{ov::intel_npu::compiler_type.name(), target_compiler_type}})
12131230
.as<int64_t>();
12141231
}
1232+
LOG_INFO("Compiler version: " << ONEAPI_VERSION_MAJOR(desc.compiler_ver) << "."
1233+
<< ONEAPI_VERSION_MINOR(desc.compiler_ver));
1234+
1235+
constexpr std::string_view compiler_gate_support_msg =
1236+
"Compiler: accurate gated matmul (MatMul -> Divide -> Tanh -> Multiply -> Result) : ";
1237+
1238+
if (desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 28)) {
1239+
// accuracy for gated matmul fixed at 7.28
1240+
desc.compiler_matmul_gate = true;
1241+
LOG_INFO(compiler_gate_support_msg << "supported");
1242+
} else {
1243+
LOG_WARN(compiler_gate_support_msg << "unsupported");
1244+
}
12151245

12161246
if (desc.arch == "5010" && desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 29)) {
12171247
// Flash attention tile is supported starting from compiler version 7.29 on NPU5010
@@ -1260,6 +1290,13 @@ ov::AnyMap get_baseline_common_config(const std::optional<NPUDesc>& npudesc) {
12601290
config.erase("NPUW_DCOFF_TYPE");
12611291
config.erase("NPUW_DCOFF_SCALE");
12621292
}
1293+
1294+
// default value is ON
1295+
// for compiler versions >= 7.28 value is ON
1296+
// for other compiler versions value is OFF
1297+
if (npudesc.has_value()) {
1298+
config.emplace("NPUW_MM_GATED", (npudesc->compiler_matmul_gate ? "YES" : "NO"));
1299+
}
12631300
return config;
12641301
}
12651302

@@ -1896,7 +1933,6 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
18961933

18971934
if (!m_is_embedding) {
18981935
if (!m_use_chunk_prefill) {
1899-
// TODO: sometimes it is ok if we cannot find any empty inputs or not?
19001936
NPUW_ASSERT(remove_empty_kv_inputs(prefill_model));
19011937
} else {
19021938
LOG_DEBUG("Don't remove input key/values from prefill model.");

src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,9 +1494,12 @@ void Partitioner::saveTailDictConstants(const std::string& func_name) {
14941494
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
14951495
std::vector<CPtr> to_keep;
14961496

1497+
ov::npuw::patterns::opt::Context ctx;
1498+
ctx.mm_gate = cfg.get<::intel_npu::NPUW_MM_GATED>();
1499+
14971500
ov::pass::GraphRewrite rewr;
1498-
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(to_keep));
1499-
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulSymm>(std::ref(to_keep));
1501+
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(ctx), std::ref(to_keep));
1502+
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulFP8>(std::ref(ctx), std::ref(to_keep));
15001503
rewr.run_on_model(model_group.front());
15011504

15021505
for (auto&& const_to_keep : to_keep) {

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,8 @@ CompressDictMatMulf32::CompressDictMatMulf32(Context::Ref ctx) {
19191919
// Const(S) ---------------------> Multiply -> to(f32) -> MatMul -> Result
19201920
// ???(Act) -------------------------------------------->
19211921

1922-
PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatMulAsymm::Results to_keep) {
1922+
PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(Context::Ref ctx,
1923+
PreserveConstDictMatMulAsymm::Results to_keep) {
19231924
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
19241925
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
19251926
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
@@ -1930,7 +1931,21 @@ PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatM
19301931
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});
19311932
auto qmmi = opp::any_input();
19321933
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtm});
1933-
auto qres = opp::wrap_type<ov::op::v0::Result>({qmm});
1934+
std::shared_ptr<Node> qres;
1935+
1936+
// MatMul -> Divide -> Tanh -> Multiply -> Result
1937+
if (ctx.get().mm_gate) {
1938+
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({qmm, opp::any_input()});
1939+
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
1940+
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
1941+
1942+
auto matmul_or =
1943+
std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{qmm->output(0), matmul_multiply->output(0)});
1944+
1945+
qres = opp::wrap_type<ov::op::v0::Result>({matmul_or});
1946+
} else {
1947+
qres = opp::wrap_type<ov::op::v0::Result>({qmm});
1948+
}
19341949

19351950
// Note: Use [=] to make sure the above objects stay alive in the callback
19361951
auto callback = [=](ov::pass::pattern::Matcher& m) {
@@ -1964,14 +1979,28 @@ PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatM
19641979
// Const(S) ----------------> Multiply -> MatMul -> Result
19651980
// ???(Act) ---------------------------->
19661981

1967-
PreserveConstDictMatMulSymm::PreserveConstDictMatMulSymm(PreserveConstDictMatMulSymm::Results to_keep) {
1982+
PreserveConstDictMatMulFP8::PreserveConstDictMatMulFP8(Context::Ref ctx, PreserveConstDictMatMulFP8::Results to_keep) {
19681983
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
19691984
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
19701985
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
19711986
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
1987+
auto optional_cvt = opp::optional<ov::op::v0::Convert>({qmuls});
19721988
auto qmmi = opp::any_input();
1973-
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qmuls});
1974-
auto qres = opp::wrap_type<ov::op::v0::Result>({qmm});
1989+
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, optional_cvt});
1990+
std::shared_ptr<Node> qres;
1991+
// // MatMul -> Divide -> Tanh -> Multiply -> Result
1992+
if (ctx.get().mm_gate) {
1993+
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({qmm, opp::any_input()});
1994+
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
1995+
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});
1996+
1997+
auto matmul_or =
1998+
std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{qmm->output(0), matmul_multiply->output(0)});
1999+
2000+
qres = opp::wrap_type<ov::op::v0::Result>({matmul_or});
2001+
} else {
2002+
qres = opp::wrap_type<ov::op::v0::Result>({qmm});
2003+
}
19752004

19762005
// Note: Use [=] to make sure the above objects stay alive in the callback
19772006
auto callback = [=](ov::pass::pattern::Matcher& m) {
@@ -1997,7 +2026,7 @@ PreserveConstDictMatMulSymm::PreserveConstDictMatMulSymm(PreserveConstDictMatMul
19972026
}
19982027
return false; // root hasn't changed
19992028
};
2000-
register_matcher(std::make_shared<opp::Matcher>(qres, "OptPreserveConstDictMatMulSymm"), std::move(callback));
2029+
register_matcher(std::make_shared<opp::Matcher>(qres, "OptPreserveConstDictMatMulFP8"), std::move(callback));
20012030
}
20022031

20032032
SliceLastMatmul::SliceLastMatmul() {

src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ struct Context {
2323
std::string pmm_dims;
2424
bool is_spatial = false;
2525
bool mm_dq_full = true;
26+
bool mm_gate = false;
2627

2728
using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
2829
using NPtr = std::shared_ptr<ov::Node>;
@@ -229,17 +230,17 @@ class PreserveConstDictMatMulAsymm : public ov::pass::MatcherPass {
229230
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
230231
using Results = std::reference_wrapper<std::vector<CPtr>>;
231232

232-
PreserveConstDictMatMulAsymm(Results to_keep);
233+
PreserveConstDictMatMulAsymm(Context::Ref ctx, Results to_keep);
233234
};
234235

235-
class PreserveConstDictMatMulSymm : public ov::pass::MatcherPass {
236+
class PreserveConstDictMatMulFP8 : public ov::pass::MatcherPass {
236237
public:
237-
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::PreserveConstDictMatMulSymm");
238+
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::PreserveConstDictMatMulFP8");
238239

239240
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
240241
using Results = std::reference_wrapper<std::vector<CPtr>>;
241242

242-
PreserveConstDictMatMulSymm(Results to_keep);
243+
PreserveConstDictMatMulFP8(Context::Ref ctx, Results to_keep);
243244
};
244245

245246
// Slice last Matmul

0 commit comments

Comments
 (0)