-
Notifications
You must be signed in to change notification settings - Fork 2.6k
[Transformations][GPU] Constant tensor deduplication pass #29052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
81a251c
b0319c0
36561a8
bb4e391
8ebcc72
3a6a3dd
a9bb558
2714bfc
11dc772
6ea44e8
01fb8e6
52a8e48
b1e498b
42ffbc4
f14c431
286250e
68ea41b
8ddc160
37f289f
49ff00c
53f6caa
7294f31
341a535
a4b3b7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,22 @@ | ||||||
// Copyright (C) 2024 Intel Corporation | ||||||
// SPDX-License-Identifier: Apache-2.0 | ||||||
// | ||||||
|
||||||
#pragma once | ||||||
|
||||||
#include "transformations_visibility.hpp" | ||||||
#include "openvino/pass/graph_rewrite.hpp" | ||||||
|
||||||
namespace ov { | ||||||
namespace pass { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor: we can use c++17 standart here: |
||||||
|
||||||
class TRANSFORMATIONS_API ConstantsReduce : public ov::pass::GraphRewrite { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. usually we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also we can change it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How so? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you add new unit tests for this transformation? |
||||||
public: | ||||||
OPENVINO_GRAPH_REWRITE_RTTI("ConstantsReduce"); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please use |
||||||
ConstantsReduce(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override; | ||||||
}; | ||||||
|
||||||
} // namespace pass | ||||||
} // namespace ov |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,110 @@ | ||||||
// Copyright (C) 2024 Intel Corporation | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
// SPDX-License-Identifier: Apache-2.0 | ||||||
// | ||||||
|
||||||
#include "transformations/common_optimizations/constants_reduce.hpp" | ||||||
#include "openvino/op/constant.hpp" | ||||||
#include "openvino/util/log.hpp" | ||||||
#include "itt.hpp" | ||||||
|
||||||
namespace ov { | ||||||
namespace pass { | ||||||
|
||||||
using BlobCacheKey = std::shared_ptr<ov::Node>; | ||||||
|
||||||
struct KeyHash { | ||||||
std::size_t operator()(const BlobCacheKey& key) const { | ||||||
std::size_t hash_value = 0; | ||||||
|
||||||
auto node = ov::as_type_ptr<op::v0::Constant>(key); | ||||||
dnkurek marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
auto type = node->get_output_element_type(0); | ||||||
auto shape = node->get_shape(); | ||||||
|
||||||
for (auto dim : shape) { | ||||||
hash_value ^= std::hash<size_t>{}(dim); | ||||||
} | ||||||
|
||||||
hash_value ^= std::hash<std::string>{}(type.c_type_string()); | ||||||
sshlyapn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
return hash_value; | ||||||
sshlyapn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
} | ||||||
}; | ||||||
|
||||||
struct KeyEqual { | ||||||
bool operator()(const BlobCacheKey& lhs, const BlobCacheKey& rhs) const { | ||||||
auto lhs_node = ov::as_type_ptr<op::v0::Constant>(lhs); | ||||||
auto rhs_node = ov::as_type_ptr<op::v0::Constant>(rhs); | ||||||
dnkurek marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
auto lhs_type = lhs_node->get_output_element_type(0); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part of code looks similar to this function Consider expert these function to tensor_util.hpp (part of dev API ) and re-use it. The Constant node can provide tensor view |
||||||
auto rhs_type = rhs_node->get_output_element_type(0); | ||||||
|
||||||
if (lhs_type != rhs_type) | ||||||
return false; | ||||||
|
||||||
auto lhs_shape = lhs_node->get_shape(); | ||||||
auto rhs_shape = rhs_node->get_shape(); | ||||||
|
||||||
if (lhs_shape != rhs_shape) | ||||||
return false; | ||||||
|
||||||
std::size_t lhs_size = lhs_node->get_byte_size(); | ||||||
std::size_t rhs_size = rhs_node->get_byte_size(); | ||||||
|
||||||
if (lhs_size != rhs_size) | ||||||
return false; | ||||||
|
||||||
// Retrieve buffer pointers | ||||||
const char* lhs_data = lhs_node->get_data_ptr<char>(); | ||||||
const char* rhs_data = rhs_node->get_data_ptr<char>(); | ||||||
|
||||||
if (lhs_data == rhs_data) | ||||||
dnkurek marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
return true; | ||||||
|
||||||
return std::memcmp(lhs_data, rhs_data, lhs_size) == 0; | ||||||
} | ||||||
}; | ||||||
|
||||||
ConstantsReduce::ConstantsReduce() {} | ||||||
|
||||||
bool ConstantsReduce::run_on_model(const std::shared_ptr<ov::Model>& m) { | ||||||
RUN_ON_MODEL_SCOPE(ConstantsReduce); | ||||||
|
||||||
std::unordered_map<BlobCacheKey, std::shared_ptr<ov::Node>, KeyHash, KeyEqual> blobMemCache; | ||||||
|
||||||
int copies = 0; | ||||||
sshlyapn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
const std::vector<std::shared_ptr<ov::Node>> ops = m->get_ops(); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
for (auto& op : ops) { | ||||||
if (!ov::is_type<ov::op::v0::Constant>(op)) continue; | ||||||
|
||||||
auto const_node = ov::as_type_ptr<op::v0::Constant>(op); | ||||||
|
||||||
// Limit size of node reading to avoid reading large tensors | ||||||
if (const_node->get_byte_size() > 256) continue; | ||||||
sshlyapn marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's better to define a macro variable to make 256 more visible and informative |
||||||
|
||||||
const auto cache_key = op; | ||||||
auto bufIter = blobMemCache.find(cache_key); | ||||||
|
||||||
if (bufIter == blobMemCache.end()) { | ||||||
blobMemCache[cache_key] = op; | ||||||
} else { | ||||||
copies++; | ||||||
auto users = const_node->get_users(); | ||||||
for (auto user : users) { | ||||||
for (size_t i = 0; i < user->get_input_size(); i++) { | ||||||
if (user->input_value(i) == op->output(0)) { | ||||||
user->input(i).replace_source_output(blobMemCache[cache_key]); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
OPENVINO_DEBUG("Reduced ", copies, " constant node duplications from model"); | ||||||
|
||||||
// Return true if we have made any replacements | ||||||
return copies > 0; | ||||||
} | ||||||
|
||||||
} // namespace pass | ||||||
} // namespace ov |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.