Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ TYPE_PRINTER(std::size_t)
#ifndef ONEAPI_MAKE_VERSION
/// @brief Generates generic 'oneAPI' API versions
# define ONEAPI_MAKE_VERSION(_major, _minor) ((_major << 16) | (_minor & 0x0000ffff))

/// @brief extract 'oneAPI' API major version
# define ONEAPI_VERSION_MAJOR(_version) ((_version) >> 16)

/// @brief extract 'oneAPI' API minor version
# define ONEAPI_VERSION_MINOR(_version) ((_version) & 0x0000ffff)

#endif // ONEAPI_MAKE_VERSION

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ DEFINE_OPT(NPUW_CWAI, bool, false, npuw::partitioning::cwai, RunTime);
DEFINE_OPT(NPUW_DQ, bool, false, npuw::partitioning::dyn_quant, RunTime);
DEFINE_OPT(NPUW_DQ_FULL, bool, true, npuw::partitioning::dyn_quant_full, RunTime);
DEFINE_OPT(NPUW_PMM, std::string, "2", npuw::partitioning::par_matmul_merge_dims, RunTime);
DEFINE_OPT(NPUW_MM_GATED, bool, true, npuw::partitioning::matmul_gate_preserve_constants, RunTime);
DEFINE_OPT(NPUW_SLICE_OUT, bool, false, npuw::partitioning::slice_out, RunTime);
DEFINE_OPT(NPUW_HOST_GATHER, bool, true, npuw::partitioning::host_gather, RunTime);
DEFINE_OPT(NPUW_SPATIAL, bool, false, npuw::partitioning::spatial, RunTime);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,15 @@ static constexpr ov::Property<bool> dyn_quant_full{"NPUW_DQ_FULL"};
*/
static constexpr ov::Property<std::string> par_matmul_merge_dims{"NPUW_PMM"};

/**
* @brief
* Type: bool.
 * Whether to preserve constants for the gated version of MatMul.
 * On some compiler versions, enabling this might produce incorrect results.
* Default value: YES
*/
static constexpr ov::Property<bool> matmul_gate_preserve_constants{"NPUW_MM_GATED"};

/**
* @brief
* Type: bool.
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/al/src/config/npuw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void intel_npu::registerNPUWOptions(OptionsDesc& desc) {
desc.add<NPUW_DQ>();
desc.add<NPUW_DQ_FULL>();
desc.add<NPUW_PMM>();
desc.add<NPUW_MM_GATED>();
desc.add<NPUW_SLICE_OUT>();
desc.add<NPUW_SPATIAL>();
desc.add<NPUW_SPATIAL_NWAY>();
Expand Down
7 changes: 5 additions & 2 deletions src/plugins/intel_npu/src/plugin/npuw/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,10 @@ bool ov::npuw::CompiledModel::should_use_quantized_host_gather(const std::shared
std::vector<CPtr> to_keep;

ov::pass::GraphRewrite rewr2;
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(to_keep));
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulSymm>(std::ref(to_keep));
ctx.mm_gate = m_cfg.get<::intel_npu::NPUW_MM_GATED>();

rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(ctx), std::ref(to_keep));
rewr2.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulFP8>(std::ref(ctx), std::ref(to_keep));
rewr2.run_on_model(model);
// FIXME: since 3-model pipeline is the default option, the tail will be separate,
// so we need to match either head or tail pattern here for host gather quantized feature to work.
Expand Down Expand Up @@ -2506,6 +2508,7 @@ void ov::npuw::CompiledModel::implement_properties() {
BIND(npuw::partitioning::dyn_quant, NPUW_DQ),
BIND(npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL),
BIND(npuw::partitioning::par_matmul_merge_dims, NPUW_PMM),
BIND(npuw::partitioning::matmul_gate_preserve_constants, NPUW_MM_GATED),
BIND(npuw::partitioning::slice_out, NPUW_SLICE_OUT),
BIND(npuw::partitioning::spatial, NPUW_SPATIAL),
BIND(npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY),
Expand Down
21 changes: 20 additions & 1 deletion src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ struct NPUDesc {
std::string arch;
int64_t max_tiles = 0;
bool compiler_dq = false;
bool compiler_matmul_gate = false;
int64_t compiler_ver = 0;
bool support_flash_attention_tile = false;
};
Expand Down Expand Up @@ -1199,6 +1200,19 @@ std::optional<NPUDesc> extract_npu_descriptor(const std::shared_ptr<const ov::IP
ov::AnyMap{{ov::intel_npu::compiler_type.name(), target_compiler_type}})
.as<int64_t>();
}
LOG_INFO("Compiler version: " << ONEAPI_VERSION_MAJOR(desc.compiler_ver) << "."
<< ONEAPI_VERSION_MINOR(desc.compiler_ver));

constexpr std::string_view compiler_gate_support_msg =
"Compiler: accurate gated matmul (MatMul -> Divide -> Tanh -> Multiply -> Result) : ";

if (desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 28)) {
// Accuracy for the gated matmul pattern was fixed in compiler version 7.28
desc.compiler_matmul_gate = true;
LOG_INFO(compiler_gate_support_msg << "supported");
} else {
LOG_WARN(compiler_gate_support_msg << "unsupported");
}

if (desc.arch == "5010" && desc.compiler_ver >= ONEAPI_MAKE_VERSION(7, 29)) {
// Flash attention tile is supported starting from compiler version 7.29 on NPU5010
Expand Down Expand Up @@ -1247,6 +1261,11 @@ ov::AnyMap get_baseline_common_config(const std::optional<NPUDesc>& npudesc) {
config.erase("NPUW_DCOFF_TYPE");
config.erase("NPUW_DCOFF_SCALE");
}

// The default is ON; it is turned off for older compilers that lack the accuracy fix
if (npudesc.has_value()) {
config.emplace("NPUW_MM_GATED", (npudesc->compiler_matmul_gate ? "YES" : "NO"));
}
return config;
}

Expand Down Expand Up @@ -1878,7 +1897,7 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
if (!m_is_embedding) {
if (!m_use_chunk_prefill) {
// TODO: sometimes it is ok if we cannot find any empty inputs or not?
NPUW_ASSERT(remove_empty_kv_inputs(prefill_model));
remove_empty_kv_inputs(prefill_model);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it still needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is especially needed and a leftover for fp8-cb4 work

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AsyaPronina indeed i found a source of the problem - in fp8 patterns we used for redirect - that one also need to be checked when empty kv inputs removed with concats etc in prefill - please have a look on actual implementation

} else {
LOG_DEBUG("Don't remove input key/values from prefill model.");
LOG_DEBUG("Ask prefill model to output key/values for prefill chunk size tokens.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1494,9 +1494,12 @@ void Partitioner::saveTailDictConstants(const std::string& func_name) {
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
std::vector<CPtr> to_keep;

ov::npuw::patterns::opt::Context ctx;
ctx.mm_gate = cfg.get<::intel_npu::NPUW_MM_GATED>();

ov::pass::GraphRewrite rewr;
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(to_keep));
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulSymm>(std::ref(to_keep));
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulAsymm>(std::ref(ctx), std::ref(to_keep));
rewr.add_matcher<ov::npuw::patterns::opt::PreserveConstDictMatMulFP8>(std::ref(ctx), std::ref(to_keep));
rewr.run_on_model(model_group.front());

for (auto&& const_to_keep : to_keep) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,8 @@ CompressDictMatMulf32::CompressDictMatMulf32(Context::Ref ctx) {
// Const(S) ---------------------> Multiply -> to(f32) -> MatMul -> Result
// ???(Act) -------------------------------------------->

PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatMulAsymm::Results to_keep) {
PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(Context::Ref ctx,
PreserveConstDictMatMulAsymm::Results to_keep) {
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
auto qzerop = opp::wrap_type<ov::op::v0::Constant>();
Expand All @@ -1930,7 +1931,21 @@ PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatM
auto qcvtm = opp::wrap_type<ov::op::v0::Convert>({qmuls});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtm});
auto qres = opp::wrap_type<ov::op::v0::Result>({qmm});
std::shared_ptr<Node> qres;

// MatMul -> Divide -> Tanh -> Multiply -> Result
if (ctx.get().mm_gate) {
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({qmm, opp::any_input()});
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});

auto matmul_or =
std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{qmm->output(0), matmul_multiply->output(0)});

qres = opp::wrap_type<ov::op::v0::Result>({matmul_or});
} else {
qres = opp::wrap_type<ov::op::v0::Result>({qmm});
}

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
Expand Down Expand Up @@ -1964,14 +1979,28 @@ PreserveConstDictMatMulAsymm::PreserveConstDictMatMulAsymm(PreserveConstDictMatM
// Const(S) ----------------> Multiply -> MatMul -> Result
// ???(Act) ---------------------------->

PreserveConstDictMatMulSymm::PreserveConstDictMatMulSymm(PreserveConstDictMatMulSymm::Results to_keep) {
PreserveConstDictMatMulFP8::PreserveConstDictMatMulFP8(Context::Ref ctx, PreserveConstDictMatMulFP8::Results to_keep) {
auto qweight = opp::wrap_type<ov::op::v0::Constant>();
auto qcoeff = opp::wrap_type<ov::op::v0::Constant>();
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
auto optional_kvt = opp::optional<ov::op::v0::Convert>({qmuls});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qmuls});
auto qres = opp::wrap_type<ov::op::v0::Result>({qmm});
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, optional_kvt});
std::shared_ptr<Node> qres;
// MatMul -> Divide -> Tanh -> Multiply -> Result
if (ctx.get().mm_gate) {
auto div = opp::wrap_type<ov::op::v1::Multiply, ov::op::v1::Divide>({qmm, opp::any_input()});
auto tanh = opp::wrap_type<ov::op::v0::Tanh>({div});
auto matmul_multiply = opp::wrap_type<ov::op::v1::Multiply>({tanh, opp::any_input()});

auto matmul_or =
std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{qmm->output(0), matmul_multiply->output(0)});

qres = opp::wrap_type<ov::op::v0::Result>({matmul_or});
} else {
qres = opp::wrap_type<ov::op::v0::Result>({qmm});
}

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
Expand All @@ -1997,7 +2026,7 @@ PreserveConstDictMatMulSymm::PreserveConstDictMatMulSymm(PreserveConstDictMatMul
}
return false; // root hasn't changed
};
register_matcher(std::make_shared<opp::Matcher>(qres, "OptPreserveConstDictMatMulSymm"), std::move(callback));
register_matcher(std::make_shared<opp::Matcher>(qres, "OptPreserveConstDictMatMulFP8"), std::move(callback));
}

SliceLastMatmul::SliceLastMatmul() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct Context {
std::string pmm_dims;
bool is_spatial = false;
bool mm_dq_full = true;
bool mm_gate = false;

using PPtr = std::shared_ptr<ov::op::v0::Parameter>;
using NPtr = std::shared_ptr<ov::Node>;
Expand Down Expand Up @@ -229,17 +230,17 @@ class PreserveConstDictMatMulAsymm : public ov::pass::MatcherPass {
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
using Results = std::reference_wrapper<std::vector<CPtr>>;

PreserveConstDictMatMulAsymm(Results to_keep);
PreserveConstDictMatMulAsymm(Context::Ref ctx, Results to_keep);
};

class PreserveConstDictMatMulSymm : public ov::pass::MatcherPass {
class PreserveConstDictMatMulFP8 : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::PreserveConstDictMatMulSymm");
OPENVINO_MATCHER_PASS_RTTI("npuw::patterns::opt::PreserveConstDictMatMulFP8");

using CPtr = std::shared_ptr<ov::op::v0::Constant>;
using Results = std::reference_wrapper<std::vector<CPtr>>;

PreserveConstDictMatMulSymm(Results to_keep);
PreserveConstDictMatMulFP8(Context::Ref ctx, Results to_keep);
};

// Slice last Matmul
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/plugin/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ void init_config(const IEngineBackend* backend, OptionsDesc& options, FilteredCo
REGISTER_OPTION(NPUW_DQ);
REGISTER_OPTION(NPUW_DQ_FULL);
REGISTER_OPTION(NPUW_PMM);
REGISTER_OPTION(NPUW_MM_GATED);
REGISTER_OPTION(NPUW_SLICE_OUT);
REGISTER_OPTION(NPUW_SPATIAL);
REGISTER_OPTION(NPUW_SPATIAL_NWAY);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_npu/src/plugin/src/properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ void Properties::registerPluginProperties() {
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::dyn_quant, NPUW_DQ);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::dyn_quant_full, NPUW_DQ_FULL);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::par_matmul_merge_dims, NPUW_PMM);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::matmul_gate_preserve_constants, NPUW_MM_GATED);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::slice_out, NPUW_SLICE_OUT);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::spatial, NPUW_SPATIAL);
TRY_REGISTER_SIMPLE_PROPERTY(ov::intel_npu::npuw::partitioning::spatial_nway, NPUW_SPATIAL_NWAY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class PropertiesManagerTests : public ov::test::behavior::OVPluginTestBase,
REGISTER_OPTION(NPUW_DQ);
REGISTER_OPTION(NPUW_DQ_FULL);
REGISTER_OPTION(NPUW_PMM);
REGISTER_OPTION(NPUW_MM_GATED);
REGISTER_OPTION(NPUW_SLICE_OUT);
REGISTER_OPTION(NPUW_SPATIAL);
REGISTER_OPTION(NPUW_SPATIAL_NWAY);
Expand Down
Loading