Skip to content

[GPU] working #30646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace pass {
class TRANSFORMATIONS_API CompressedGatherTransformation;
class TRANSFORMATIONS_API ConvertGatherToGatherCompressed;
class TRANSFORMATIONS_API MoveDecompressionAfterGather;
class TRANSFORMATIONS_API MarkFoldConstForGatherCompressed;

} // namespace pass
} // namespace ov
Expand Down Expand Up @@ -58,11 +59,21 @@ class ov::pass::MoveDecompressionAfterGather : public ov::pass::MatcherPass {
MoveDecompressionAfterGather();
};

/*
* MarkFoldConstForGatherCompressed enables or disables constant folding for the GatherCompressed node
*/
// Matcher pass declaration. Only the interface is visible here: the class derives from
// ov::pass::MatcherPass and exposes a default constructor that is defined out of line.
// NOTE(review): the matching pattern and callback live in the constructor definition,
// which is not part of this header - confirm the exact behavior there.
class ov::pass::MarkFoldConstForGatherCompressed : public ov::pass::MatcherPass {
public:
// Project RTTI macro invoked with the pass name string; presumably supplies the
// type-info plumbing for this pass - verify against the macro definition.
OPENVINO_MATCHER_PASS_RTTI("MarkFoldConstForGatherCompressed");
// Default constructor, declared only; its definition is outside this view.
MarkFoldConstForGatherCompressed();
};

// Graph rewrite that groups the compressed-Gather matcher passes into a single pass.
// The constructor adds the matchers in this fixed order:
//   1. ConvertGatherToGatherCompressed
//   2. MoveDecompressionAfterGather
//   3. MarkFoldConstForGatherCompressed
class ov::pass::CompressedGatherTransformation : public ov::pass::GraphRewrite {
public:
    OPENVINO_GRAPH_REWRITE_RTTI("CompressedGatherTransformation");

    CompressedGatherTransformation() {
        // Registration order is kept exactly as declared above.
        this->add_matcher<ov::pass::ConvertGatherToGatherCompressed>();
        this->add_matcher<ov::pass::MoveDecompressionAfterGather>();
        this->add_matcher<ov::pass::MarkFoldConstForGatherCompressed>();
    }
};
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,22 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
// MOCTransformations contain StridedSliceOptimization transformation,
// so we must call SliceToStridedSlice before MOCTransformations call
REGISTER_PASS(manager, SliceToStridedSlice, true)

std::string dump_graphs_path = "graph/";
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_common_opt_0.xml",
dump_graphs_path + "ov_model_common_opt_0.bin");
}

// Disable low_precision_enabled as all plugins handle low-precision sub-graph manually
// before CommonOptimization pipeline execution
REGISTER_PASS(manager, MOCTransformations, true, false)

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_common_opt_1.xml",
dump_graphs_path + "ov_model_common_opt_1.bin");
}

// Enable conversion of FP16 IR to the legacy representation; each plugin has to disable it
// once support for FP16 IR is implemented
REGISTER_PASS(manager, ConvertCompressedOnlyToLegacy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar()
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
auto weights = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(weights_m).get_node_shared_ptr());
auto zero_point_convert = ov::as_type_ptr<ov::op::v0::Convert>(pattern_map.at(zero_point_m).get_node_shared_ptr());
std::shared_ptr<ov::op::v0::Constant> zero_point;
if (pattern_map.count(float_zero_point_m)) {
const auto& float_zp = pattern_map.at(float_zero_point_m);
Expand Down Expand Up @@ -66,8 +67,61 @@ ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar()
float zp_value;
if (!ov::op::util::get_single_value(zero_point, zp_value))
return false;
const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
return ov::replace_node_update_name(zero_point, new_zp);

// if (zero_point->get_friendly_name() == "model.embed_tokens.zp_to_f16") {
std::cout << ">> Pass ConvertU4WeightsZeroPointToScalar : " << zero_point->get_friendly_name() << std::endl;
std::cout << " -- float_zero_point_m : " << pattern_map.count(float_zero_point_m)
<< ", zp_value : " << zp_value << ", zp_type : " << zero_point->get_element_type() << std::endl;
// }

// [TEST]
// const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
// bool result = ov::replace_node_update_name(zero_point, new_zp);
bool result;
float temp;
// if (pattern_map.count(float_zero_point_m)) {
if (true) {
const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
result = ov::replace_node_update_name(zero_point, new_zp);

std::cout << " -- Pass ConvertU4WeightsZeroPointToScalar : " << new_zp->get_friendly_name() << std::endl;
if (!ov::op::util::get_single_value(zero_point, temp))
return false;
} else {
#if 0
std::cout << " -- No convert u4 weight zp to scalar!!!!" << std::endl;
return false;
#endif

const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
result = ov::replace_node_update_name(zero_point, new_zp);

// Old
// const auto new_zp_convert = ov::op::v0::Constant::create(zero_point_convert->get_element_type(), {}, {zp_value});
// result = ov::replace_node_update_name(zero_point_convert, new_zp);

std::cout << " -- Pass ConvertU4WeightsZeroPointToScalar : " << zero_point_convert->get_friendly_name() << std::endl;
// auto zero_point_convert_const = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(zero_point_m).get_node_shared_ptr());
// if (!ov::op::util::get_single_value(zero_point_convert_const, temp))
// return false;
}
// std::cout << " -- Pass ConvertU4WeightsZeroPointToScalar : " << new_zp->get_friendly_name() << std::endl;

// float temp;
#if 0
{
if (!ov::op::util::get_single_value(zero_point, temp))
return false;

std::cout << " -- After replace zp_value : " << temp << std::endl;
if (!pattern_map.count(float_zero_point_m) || true) {
std::cout << " -- After replace zero_point : " << zero_point->get_shape() << std::endl;
// std::cout << " -- After replace zero_point : " << zero_point_convert->get_shape() << std::endl;
}
}
#endif

return result;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(subtract_m, matcher_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
#include "transformations/smart_reshape/reshape_sinking.hpp"
#include "transformations/symbolic_transformations/symbolic_optimizations.hpp"

#include "openvino/pass/serialize.hpp"

static ov::PartialShape prepare_dynamic_shape(const ov::PartialShape& shape) {
auto new_shape = ov::PartialShape::dynamic(shape.rank());
if (shape.rank().is_static())
Expand Down Expand Up @@ -148,6 +150,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
if (!m_use_shapes) {
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
}

std::string dump_graphs_path = "graph/";
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_0_init.xml",
dump_graphs_path + "ov_model_moc_trans_0_init.bin");
}

// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults
// should be performed before first ConstantFolding call.
// The passes can detach graph branches where a zero dimension is calculated.
Expand Down Expand Up @@ -256,12 +265,29 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
ADD_MATCHER(common_fusions, AdaptivePoolToReduce)
// [TEST]
// std::string dump_graphs_path = GPU_DEBUG_VALUE_OR(config.get_dump_graphs_path(), "");
// std::string dump_graphs_path = "graph/";
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_1_before_ConvertU4WeightsZeroPointToScalar.xml",
dump_graphs_path + "ov_model_moc_trans_1.bin");
}
ADD_MATCHER(common_fusions, ConvertU4WeightsZeroPointToScalar)
common_fusions->set_name("ov::pass::CommonFusions");

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_2_after_ConvertU4WeightsZeroPointToScalar.xml",
dump_graphs_path + "ov_model_moc_trans_2.bin");
}

REGISTER_PASS(manager, BinarizeWeights)
REGISTER_PASS(manager, ConvToBinaryConv)

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_3.xml",
dump_graphs_path + "ov_model_moc_trans_3.bin");
}

auto decomp = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(decomp, BatchNormDecomposition)
ADD_MATCHER(decomp, ConvertDivideWithConstant)
Expand All @@ -270,6 +296,11 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(decomp, ConvertConvertPromoteTypes)
manager.register_pass<ov::pass::LinOpSequenceFusion>();

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_4.xml",
dump_graphs_path + "ov_model_moc_trans_4.bin");
}

auto multiply_fusions = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(multiply_fusions, ConvolutionMultiplyFusion)
ADD_MATCHER(multiply_fusions, GroupConvolutionMultiplyFusion)
Expand All @@ -280,8 +311,17 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(multiply_fusions, MultiplyConvolutionBackpropDataFusion)
ADD_MATCHER(multiply_fusions, MultiplyGroupConvolutionBackpropDataFusion)
multiply_fusions->set_name("ov::pass::MultiplyFusions");
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_4_1.xml",
dump_graphs_path + "ov_model_moc_trans_4_1.bin");
}
REGISTER_PASS(manager, ConstantFolding)

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_5.xml",
dump_graphs_path + "ov_model_moc_trans_5.bin");
}

auto fq_fusions = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
ADD_MATCHER(fq_fusions, FakeQuantizeReshapeFusion)
Expand All @@ -293,9 +333,21 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, ReverseInputChannelsFusion)
REGISTER_PASS(manager, AlignEltwiseInputRanks)
REGISTER_PASS(manager, SharedOpOptimization)

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_5.xml",
dump_graphs_path + "ov_model_moc_trans_5.bin");
}

REGISTER_PASS(manager, ConstantFolding)
REGISTER_PASS(manager, SymbolicOptimizations)
REGISTER_PASS(manager, ResolveNameCollisions, true);

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_6.xml",
dump_graphs_path + "ov_model_moc_trans_6.bin");
}

manager.run_passes(f);

if (!m_use_shapes) {
Expand All @@ -310,5 +362,11 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
}
f->validate_nodes_and_infer_types();

ov::pass::Manager tmp_manager(get_pass_config(), "tmp_MOC");
if (!dump_graphs_path.empty()) {
tmp_manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_moc_trans_7.xml",
dump_graphs_path + "ov_model_moc_trans_7.bin");
}

return false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,15 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
{ov::op::v1::Reverse::get_type_info_static(), extend_reverse_type},
};

pass::Manager manager(get_pass_config(), "KeepPrecisionSensitiveInFP32:ConvertPrecisionResult");
std::string dump_graphs_path = "graph/";
static int net_id = 0;
if (!dump_graphs_path.empty()) {
manager.register_pass<pass::Serialize>(dump_graphs_path + "ov_model_convert_precision_0_" + std::to_string(net_id) + ".xml",
dump_graphs_path + "ov_model_convert_precision_0_" + std::to_string(net_id) + ".bin");
net_id++;
}

bool is_changed = convert_precision(*this,
f,
type_to_fuse,
Expand All @@ -507,11 +516,33 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
m_convert_input_output_precision,
m_store_original_precision_as_rt_attribute);

std::cout << ">> IN ConvertPrecision, convert_precision => is_changed : " << (is_changed ? "true" : "false") << std::endl;

// to remove extra converts
if (m_keep_precision_sensitive_in_fp32) {
pass::Manager manager(get_pass_config(), "KeepPrecisionSensitiveInFP32:RemoveConverts");
// pass::Manager manager(get_pass_config(), "KeepPrecisionSensitiveInFP32:RemoveConverts");
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_convert_precision_1_" + std::to_string(net_id) + ".xml",
dump_graphs_path + "ov_model_convert_precision_1_" + std::to_string(net_id) + ".bin");
net_id++;
}
manager.register_pass<pass::EnableDecompressionConvertConstantFolding>();
// manager.run_passes(f);

if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_convert_precision_2_" + std::to_string(net_id) + ".xml",
dump_graphs_path + "ov_model_convert_precision_2_" + std::to_string(net_id) + ".bin");
net_id++;
}

manager.register_pass<pass::ConstantFolding>();
// manager.run_passes(f);
if (!dump_graphs_path.empty()) {
manager.register_pass<ov::pass::Serialize>(dump_graphs_path + "ov_model_convert_precision_3_" + std::to_string(net_id) + ".xml",
dump_graphs_path + "ov_model_convert_precision_3_" + std::to_string(net_id) + ".bin");
net_id++;
}

manager.run_passes(f);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ ov::pass::KeepConstPrecision::KeepConstPrecision(const element::TypeVector& prec
for (const auto& pattern_node : keep_const_precisions) {
if (pt_map.count(pattern_node.first)) {
auto node = pt_map.at(pattern_node.first).get_node_shared_ptr();
if (node->get_friendly_name() == "model.embed_tokens.zp_to_f16" ||
node->get_friendly_name() == "model.embed_tokens.zp_const") {
std::cout << ">> In KeepConstPrecisionm, " << node->get_friendly_name() << std::endl;
}

if (ov::as_type_ptr<v0::Constant>(node) && check_precision(precisions)(node->output(0))) {
if (pattern_node.second) {
ov::disable_keep_const_precision(node);
Expand Down
Loading
Loading