
Hetero support continuous batching #30371

Open · wants to merge 13 commits into master
9 changes: 8 additions & 1 deletion src/plugins/hetero/src/plugin.cpp
@@ -24,6 +24,7 @@
#include "openvino/runtime/properties.hpp"
#include "openvino/util/common_util.hpp"
#include "properties.hpp"
#include "remote_context.hpp"

ov::hetero::Plugin::Plugin() {
set_device_name("HETERO");
@@ -35,6 +36,7 @@ std::shared_ptr<ov::ICompiledModel> ov::hetero::Plugin::compile_model(const std:

auto config = Configuration{properties, m_cfg};
auto compiled_model = std::make_shared<CompiledModel>(model->clone(), shared_from_this(), config);
execution_devices = compiled_model->get_property("EXECUTION_DEVICES").as<std::vector<std::string>>();
return compiled_model;
}

@@ -329,5 +331,10 @@ ov::SoPtr<ov::IRemoteContext> ov::hetero::Plugin::create_context(const ov::AnyMa
}

ov::SoPtr<ov::IRemoteContext> ov::hetero::Plugin::get_default_context(const ov::AnyMap& remote_properties) const {
OPENVINO_NOT_IMPLEMENTED;
std::map<std::string, ov::SoPtr<ov::IRemoteContext>> contexts_for_tp;
    OPENVINO_ASSERT(execution_devices.size() >= 1, "There are no execution devices in HETERO.");
for (auto device : execution_devices) {
contexts_for_tp.insert({device, get_core()->get_default_context(device)});
}
return std::make_shared<ov::hetero::RemoteContext>(contexts_for_tp);
}
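
The aggregated default context is reachable through the regular public API. Below is a minimal usage sketch, assuming an illustrative device string and model path (neither comes from this PR); note that the plugin only knows its execution devices after compile_model, so get_default_context has to be called after compilation.

```cpp
// Usage sketch only; device names and the model path are assumptions.
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");

    // compile_model records the execution devices inside the HETERO plugin;
    // calling get_default_context before any compilation trips the assert above.
    auto compiled = core.compile_model(model, "HETERO:GPU.0,GPU.1");

    // The returned context holds one default context per execution device;
    // create_tensor fans out to each of them and wraps the results in a
    // single hetero remote tensor.
    auto context = core.get_default_context("HETERO");
    auto tensor = context.create_tensor(ov::element::f32, ov::Shape{2, 128});
    return 0;
}
```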
2 changes: 2 additions & 0 deletions src/plugins/hetero/src/plugin.hpp
@@ -71,6 +71,8 @@ class Plugin : public ov::IPlugin {
Configuration m_cfg;

mutable size_t independent_submodel_size = 0;

mutable std::vector<std::string> execution_devices;
};

} // namespace hetero
44 changes: 44 additions & 0 deletions src/plugins/hetero/src/remote_context.cpp
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "remote_context.hpp"

#include <memory>

#include "openvino/runtime/make_tensor.hpp"
#include "remote_tensor.hpp"

namespace ov {
namespace hetero {

RemoteContext::RemoteContext(std::map<std::string, ov::SoPtr<ov::IRemoteContext>> contexts) {
m_contexts = contexts;
}

const ov::AnyMap& RemoteContext::get_property() const {
return m_contexts.begin()->second->get_property();
}

std::shared_ptr<RemoteContext> RemoteContext::get_this_shared_ptr() {
return std::static_pointer_cast<RemoteContext>(shared_from_this());
}

ov::SoPtr<ov::IRemoteTensor> RemoteContext::create_tensor(const ov::element::Type& type,
const ov::Shape& shape,
const ov::AnyMap& params) {
std::vector<ov::SoPtr<ov::IRemoteTensor>> tensors;
for (auto& item : m_contexts) {
auto a = item.second->create_tensor(type, shape, params);
tensors.emplace_back(a);
}
return std::make_shared<ov::hetero::RemoteTensor>(get_this_shared_ptr(), tensors);
}

const std::string& RemoteContext::get_device_name() const {
static const std::string name = "HETERO";
return name;
}

} // namespace hetero
} // namespace ov
32 changes: 32 additions & 0 deletions src/plugins/hetero/src/remote_context.hpp
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>

#include "openvino/runtime/iremote_context.hpp"

namespace ov {
namespace hetero {
class RemoteContext : public ov::IRemoteContext {
public:
using Ptr = std::shared_ptr<RemoteContext>;

RemoteContext(std::map<std::string, ov::SoPtr<ov::IRemoteContext>> contexts);

const std::string& get_device_name() const override;
const ov::AnyMap& get_property() const override;

ov::SoPtr<ov::IRemoteTensor> create_tensor(const ov::element::Type& type,
const ov::Shape& shape,
const ov::AnyMap& params) override;

private:
std::shared_ptr<RemoteContext> get_this_shared_ptr();
std::map<std::string, ov::SoPtr<ov::IRemoteContext>> m_contexts;
};

} // namespace hetero
} // namespace ov
102 changes: 102 additions & 0 deletions src/plugins/hetero/src/remote_tensor.cpp
@@ -0,0 +1,102 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "remote_tensor.hpp"

namespace ov {
namespace hetero {

RemoteTensor::RemoteTensor(std::shared_ptr<RemoteContext> context, std::vector<ov::SoPtr<ov::IRemoteTensor>> tensors)
: m_context(context),
m_ordered_tensor(tensors) {
for (auto& tensor : tensors) {
auto remote_tensor = std::dynamic_pointer_cast<ov::IRemoteTensor>(tensor._ptr);
m_remote_tensors.emplace_back(remote_tensor);
auto device_name = remote_tensor->get_device_name();
m_tensors.insert({device_name, tensor});
}
}

const std::string& RemoteTensor::get_device_name() const {
return m_context->get_device_name();
}

const ov::element::Type& RemoteTensor::get_element_type() const {
return m_tensors.begin()->second->get_element_type();
}

const ov::Strides& RemoteTensor::get_strides() const {
return m_tensors.begin()->second->get_strides();
}

const AnyMap& RemoteTensor::get_properties() const {
return m_tensors.begin()->second->get_properties();
}

const ov::Shape& RemoteTensor::get_shape() const {
return m_tensors.begin()->second->get_shape();
}

std::shared_ptr<RemoteContext> RemoteTensor::get_context() const {
return m_context;
}

ov::SoPtr<ov::IRemoteTensor> RemoteTensor::get_tensor(int index) const {
return m_ordered_tensor[index];
}

ov::SoPtr<ov::IRemoteTensor> RemoteTensor::get_tensor_by_name(const std::string device_name) const {
return m_tensors.at(device_name);
}

void RemoteTensor::set_shape(ov::Shape shape) {
for (auto it = m_tensors.begin(); it != m_tensors.end(); ++it) {
it->second->set_shape(shape);
}
}

void RemoteTensor::copy_to(const std::shared_ptr<ov::ITensor>& dst,
size_t src_offset,
size_t dst_offset,
const ov::Shape& roi_shape) const {
if (auto remote = std::dynamic_pointer_cast<ov::hetero::RemoteTensor>(dst)) {
int i = 0;
for (auto& tensor : m_remote_tensors) {
auto itensor = std::dynamic_pointer_cast<ov::ITensor>(remote->get_tensor(i)._ptr);
tensor->copy_to(itensor, src_offset, dst_offset, roi_shape);
i++;
}
} else {
int i = 0;
for (auto& tensor : m_remote_tensors) {
tensor->copy_to(dst, src_offset, dst_offset + i * get_strides()[0], roi_shape);
i++;
}
}
Contributor:
Have you tried beam search? Can it work with HETERO? It had problems with tensor parallel; not sure about HETERO.

Contributor Author:
Hi bell, I haven't tried beam search yet; I will try it.

}

void RemoteTensor::copy_from(const std::shared_ptr<const ov::ITensor>& src,
size_t src_offset,
size_t dst_offset,
const ov::Shape& roi_shape) {
if (auto remote = std::dynamic_pointer_cast<const ov::hetero::RemoteTensor>(src)) {
int i = 0;
for (auto& tensor : m_remote_tensors) {
auto itensor = std::dynamic_pointer_cast<ov::ITensor>(remote->get_tensor(i)._ptr);
tensor->copy_from(itensor, src_offset, dst_offset, roi_shape);
i++;
}
} else {
auto new_roi_shape = get_shape();
new_roi_shape[0] = roi_shape[0];
int i = 0;
for (auto& tensor : m_remote_tensors) {
tensor->copy_from(src, src_offset + i * get_strides()[0], dst_offset, new_roi_shape);
i++;
}
}
}

} // namespace hetero
} // namespace ov
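
In the non-remote branches of copy_to and copy_from above, the `i * get_strides()[0]` term places (or reads) each sub-tensor's data one outer-dimension stride apart, so the composite tensor is effectively split along dimension 0 across devices. The following is a conceptual sketch of that offset arithmetic on plain host buffers, with all names and sizes invented for illustration.

```cpp
// Conceptual sketch in plain C++ (not the OpenVINO tensor API): gather N
// per-device chunks into one host buffer, each chunk landing at an offset of
// i * row_stride_bytes, mirroring dst_offset + i * get_strides()[0] above.
#include <cstddef>
#include <cstring>
#include <vector>

void gather_chunks(const std::vector<const char*>& device_chunks,
                   char* host_dst,
                   size_t row_stride_bytes,
                   size_t chunk_bytes) {
    for (size_t i = 0; i < device_chunks.size(); ++i) {
        // The i-th device contribution starts i outer-dimension strides into
        // the destination buffer.
        std::memcpy(host_dst + i * row_stride_bytes, device_chunks[i], chunk_bytes);
    }
}
```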
56 changes: 56 additions & 0 deletions src/plugins/hetero/src/remote_tensor.hpp
@@ -0,0 +1,56 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <map>
#include <string>

#include "openvino/runtime/iremote_tensor.hpp"
#include "remote_context.hpp"

namespace ov {
namespace hetero {

class RemoteTensor : public ov::IRemoteTensor {
public:
RemoteTensor(std::shared_ptr<ov::hetero::RemoteContext> context, std::vector<ov::SoPtr<ov::IRemoteTensor>> tensors);

const std::string& get_device_name() const override;

const ov::element::Type& get_element_type() const override;

const ov::Strides& get_strides() const override;

const AnyMap& get_properties() const override;

const ov::Shape& get_shape() const override;

std::shared_ptr<RemoteContext> get_context() const;

ov::SoPtr<ov::IRemoteTensor> get_tensor(int index) const;

ov::SoPtr<ov::IRemoteTensor> get_tensor_by_name(const std::string device_name) const;

void set_shape(ov::Shape shape) override;

void copy_to(const std::shared_ptr<ov::ITensor>& dst,
size_t src_offset,
size_t dst_offset,
const ov::Shape& roi_shape) const override;

void copy_from(const std::shared_ptr<const ov::ITensor>& src,
size_t src_offset,
size_t dst_offset,
const ov::Shape& roi_shape) override;

private:
std::shared_ptr<RemoteContext> m_context;
std::vector<ov::SoPtr<ov::IRemoteTensor>> m_ordered_tensor;
std::map<std::string, ov::SoPtr<ov::IRemoteTensor>> m_tensors;
std::vector<std::shared_ptr<ov::IRemoteTensor>> m_remote_tensors;
};

} // namespace hetero
} // namespace ov
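
get_tensor_by_name lets plugin-side code pull a single device's sub-tensor back out of the composite tensor. A hypothetical sketch of that lookup follows; the helper and the device name are assumptions, not code from this PR.

```cpp
// Hypothetical helper; assumes the caller already holds a HETERO remote
// tensor behind an ov::SoPtr<ov::IRemoteTensor>. "GPU.0" is illustrative.
#include "openvino/core/except.hpp"
#include "remote_tensor.hpp"

ov::SoPtr<ov::IRemoteTensor> get_gpu_part(const ov::SoPtr<ov::IRemoteTensor>& tensor) {
    // Downcast to the hetero wrapper, then look up the per-device sub-tensor
    // registered under its device name in the constructor.
    auto hetero_tensor = std::dynamic_pointer_cast<ov::hetero::RemoteTensor>(tensor._ptr);
    OPENVINO_ASSERT(hetero_tensor, "Expected a HETERO remote tensor");
    return hetero_tensor->get_tensor_by_name("GPU.0");
}
```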
71 changes: 69 additions & 2 deletions src/plugins/hetero/src/subgraph_collector.cpp
@@ -11,10 +11,13 @@
#include "openvino/core/except.hpp"
#include "openvino/core/graph_util.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/utils/utils.hpp"

namespace {

template <typename Set>
@@ -606,6 +609,70 @@ std::pair<ov::hetero::SubgraphsVector, ov::hetero::SubgraphsMappingInfo> ov::het
return subgraph_collector.run();
}

void ov::hetero::fix_model_with_paged_attention(std::shared_ptr<ov::Model>& model) {
using NodePtr = std::shared_ptr<ov::Node>;
std::unordered_set<NodePtr> has_visited_transpose;
std::vector<NodePtr> vector_visited_transpose;

std::function<NodePtr(NodePtr)> find_first_transpose_before_pa = [&](NodePtr root_node) -> NodePtr {
auto get_output_node = [](const ov::Output<ov::Node>& output) -> NodePtr {
return output.get_node_shared_ptr();
};
auto get_input_node = [&get_output_node](const ov::Input<ov::Node>& input) -> NodePtr {
return get_output_node(input.get_source_output());
};
auto cur_node = get_input_node(root_node->inputs()[0]);
if (ov::is_type<ov::op::v1::Transpose>(cur_node)) {
return cur_node;
}
return find_first_transpose_before_pa(cur_node);
};

auto find_transpose_ops = [&](NodePtr node) {
for (size_t i = 0; i < 3; i++) {
auto first_transpose_before_pa = find_first_transpose_before_pa(node->get_input_node_shared_ptr(i));
if (has_visited_transpose.insert(first_transpose_before_pa).second) {
vector_visited_transpose.push_back(first_transpose_before_pa);
}
}
};

for (auto& op : model->get_ordered_ops()) {
if (ov::is_type<ov::op::PagedAttentionExtension>(op)) {
find_transpose_ops(op);
} else if (const auto& subgraph = ov::as_type_ptr<ov::hetero::op::DeviceSubgraph>(op)) {
for (auto& node : subgraph->get_function()->get_ordered_ops()) {
if (ov::is_type<ov::op::PagedAttentionExtension>(node)) {
find_transpose_ops(node);
}
}
}
}

for (auto& node : vector_visited_transpose) {
std::map<size_t, NodePtr> org_users;
auto output_shape = node->get_output_partial_shape(0);
int num_key_value_heads = static_cast<int>(output_shape[2].get_length());
int head_dim = static_cast<int>(output_shape[3].get_length());
if (output_shape[1].is_dynamic()) {
for (auto u : node->get_users()) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == node) {
org_users.insert({idx, u});
}
}
}
auto new_shape =
ov::op::v0::Constant::create(element::Type_t::u64, {4}, {-1, 1, num_key_value_heads, head_dim});
auto new_reshape = std::make_shared<ov::op::v1::Reshape>(node, new_shape, false);
for (auto& iter : org_users) {
iter.second->input(iter.first).replace_source_output(new_reshape->output(0));
}
node->clear_control_dependencies();
}
}
}

ov::hetero::SubgraphsMappingInfo ov::hetero::mask_model_subgraphs_by_ops(std::shared_ptr<ov::Model>& model,
ov::SupportedOpsMap& supported_ops,
const bool dump_dot_files,
@@ -691,7 +758,7 @@ ov::hetero::SubgraphsMappingInfo ov::hetero::mask_model_subgraphs_by_ops(std::sh
merge_submodels(submodels, mapping_info._submodels_input_to_prev_output);

model = submodels[0];

fix_model_with_paged_attention(model);
// Finally update mapping information according to the new operation order
std::map<size_t, size_t> subgraph_id_map;
std::map<size_t, std::map<size_t, size_t>> input_id_map;
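
For reference, this is the rewrite fix_model_with_paged_attention performs on each Transpose that feeds PagedAttention with a dynamic second output dimension: a Reshape to [-1, 1, num_key_value_heads, head_dim] is inserted and all existing consumers are re-routed to it. Below is a self-contained sketch of that single-node rewrite; the helper name and the i64 shape constant are chosen for illustration rather than copied from the PR.

```cpp
// Illustrative sketch of the Transpose -> Reshape -> consumers rewrite.
#include <cstdint>
#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"

std::shared_ptr<ov::Node> insert_reshape_after(const std::shared_ptr<ov::Node>& transpose,
                                               int64_t num_key_value_heads,
                                               int64_t head_dim) {
    // Fold the dynamic dimension into dim 0 so downstream PagedAttention sees
    // a static [-1, 1, num_key_value_heads, head_dim] layout.
    auto new_shape = ov::op::v0::Constant::create(
        ov::element::i64, ov::Shape{4},
        std::vector<int64_t>{-1, 1, num_key_value_heads, head_dim});
    auto reshape = std::make_shared<ov::op::v1::Reshape>(transpose, new_shape, false);

    // Re-route every existing consumer of the Transpose to read from the new
    // Reshape (skipping the Reshape itself, which now also consumes it).
    for (const auto& input : transpose->output(0).get_target_inputs()) {
        if (input.get_node() != reshape.get()) {
            input.replace_source_output(reshape->output(0));
        }
    }
    return reshape;
}
```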
2 changes: 2 additions & 0 deletions src/plugins/hetero/src/subgraph_collector.hpp
@@ -89,5 +89,7 @@ SubgraphsMappingInfo mask_model_subgraphs_by_ops(std::shared_ptr<ov::Model>& mod
const bool dump_dot_files = false,
const std::string default_device = "");

void fix_model_with_paged_attention(std::shared_ptr<ov::Model>& model);

} // namespace hetero
} // namespace ov