Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 125 additions & 3 deletions src/plugins/intel_cpu/src/nodes/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
#include "graph_context.h"
#include "memory_desc/cpu_memory_desc.h"
#include "node.h"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/executor_config.hpp"
#include "nodes/executors/executor_factory.hpp"
#include "nodes/executors/implementations.hpp"
#include "nodes/executors/memory_arguments.hpp"
#include "nodes/node_config.h"
#include "onednn/dnnl.h"
#include "openvino/core/except.hpp"
Expand Down Expand Up @@ -92,6 +97,7 @@ Concat::Concat(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr& co
}
CPU_NODE_ASSERT(axis < static_cast<int64_t>(inRank) && axis >= 0, "has invalid value of axis parameter: ", axis);
this->axis = axis;
m_attrs.axis = axis;
}

void Concat::getSupportedDescriptors() {
Expand Down Expand Up @@ -242,12 +248,49 @@ void Concat::initSupportedPrimitiveDescriptors() {
// Optimized inplace case
for (auto refPdIndex : pdIndexesToReuse) {
auto config = supportedPrimitiveDescriptors[refPdIndex].getConfig();
;
for (auto& inConf : config.inConfs) {
inConf.inPlace(0);
}
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}

const auto& concatImplementations = getImplementations<ConcatAttrs>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify why we access the implementation list here?
In general, we only need to create and use a factory, rather than accessing the implementation list directly.
Also, after the refactoring, `getSupportedDescriptors()` becomes empty, and most of the logic is moved to `initSupportedPrimitiveDescriptors()` and `createPrimitive()`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if (!concatImplementations.empty()) {
const auto& creatorsMap = BlockedDescCreator::getCommonCreators();
auto pushExecutorDesc = [&](LayoutType layoutType) {
NodeConfig nodeConfig;
nodeConfig.outConfs.resize(1);
nodeConfig.inConfs.resize(getParentEdges().size());

MemoryDescArgs descs;
descs.reserve(getParentEdges().size() + 1);
for (size_t i = 0; i < getParentEdges().size(); ++i) {
auto srcDesc = creatorsMap.at(layoutType)->createSharedDesc(inputPrecision, getInputShapeAtPort(i));
nodeConfig.inConfs[i].setMemDesc(srcDesc);
nodeConfig.inConfs[i].inPlace(-1);
nodeConfig.inConfs[i].constant(false);
descs[ARG_SRC + i] = srcDesc;
}

auto dstDesc = creatorsMap.at(layoutType)->createSharedDesc(outputPrecision, getOutputShapeAtPort(0));
nodeConfig.outConfs[0].setMemDesc(dstDesc);
nodeConfig.outConfs[0].inPlace(-1);
nodeConfig.outConfs[0].constant(false);
descs[ARG_DST] = dstDesc;

const executor::Config<ConcatAttrs> config{descs, m_attrs};
const bool supported =
std::any_of(concatImplementations.begin(), concatImplementations.end(), [&](const auto& impl) {
return impl.supports(config, memoryFormatFilter);
});
if (supported) {
supportedPrimitiveDescriptors.emplace_back(nodeConfig, impl_desc_type::undef);
}
};

pushExecutorDesc(LayoutType::ncsp);
pushExecutorDesc(LayoutType::nspc);
}
}

void Concat::selectOptimalPrimitiveDescriptor() {
Expand Down Expand Up @@ -354,14 +397,27 @@ void Concat::selectOptimalPrimitiveDescriptor() {
return;
}

// if there are more than one PD with similar data layouts - select the optimized one
for (auto indx : canSelectPrimitive) {
if (supportedPrimitiveDescriptors[indx].getImplementationType() == impl_desc_type::undef) {
selectPrimitiveDescriptorByIndex(static_cast<int>(indx));
return;
}
}

for (auto indx : canSelectPrimitive) {
if (supportedPrimitiveDescriptors[indx].getImplementationType() == impl_desc_type::unknown) {
selectPrimitiveDescriptorByIndex(static_cast<int>(indx));
return;
}
}

for (auto indx : canSelectPrimitive) {
if (supportedPrimitiveDescriptors[indx].getImplementationType() == impl_desc_type::ref) {
selectPrimitiveDescriptorByIndex(static_cast<int>(indx));
return;
}
}

// if there are no matching data layouts, select first optimized implementation
for (size_t i = 0; i < supportedPrimitiveDescriptors.size(); i++) {
if (canBeInPlace && supportedPrimitiveDescriptors[i].getImplementationType() == impl_desc_type::unknown) {
Expand Down Expand Up @@ -389,10 +445,29 @@ void Concat::prepareParams() {
return;
}

auto* selectedPd = getSelectedPrimitiveDescriptor();
CPU_NODE_ASSERT(selectedPd, "Preferable primitive descriptor is not set.");

if (useExecutor && m_executor) {
for (size_t i = 0; i < getParentEdges().size(); ++i) {
m_memory[ARG_SRC + i] = getSrcMemoryAtPort(i);
}
m_memory[ARG_DST] = getDstMemoryAtPort(0);

if (m_executor->update(m_memory)) {
selectedPd->setImplementationType(m_executor->implType());
return;
}

// Fallback to oneDNN/ref concat when executor update is not applicable for runtime shapes.
useExecutor = false;
m_executor.reset();
selectedPd->setImplementationType(impl_desc_type::ref);
}

const auto& dstMemPtr = getDstMemoryAtPort(0);
CPU_NODE_ASSERT(dstMemPtr && dstMemPtr->isDefined(), "Destination memory is undefined.");
auto dstMemDesc = dstMemPtr->getDescWithType<BlockedMemoryDesc>();
CPU_NODE_ASSERT(getSelectedPrimitiveDescriptor(), "Preferable primitive descriptor is not set.");

const auto& outputStrides = dstMemDesc->getStrides();
size_t curConcatOffset = 0;
Expand Down Expand Up @@ -502,6 +577,46 @@ size_t Concat::inverseOrder(const VectorDims& order, size_t axis) {
return -1;
}

void Concat::createPrimitive() {
auto* selectedPd = getSelectedPrimitiveDescriptor();
CPU_NODE_ASSERT(selectedPd, "Preferable primitive descriptor is not set.");

if (!isInPlace()) {
m_memory.clear();
m_memory.reserve(getParentEdges().size() + 1);
for (size_t i = 0; i < getParentEdges().size(); ++i) {
m_memory[ARG_SRC + i] = getSrcMemoryAtPort(i);
}
m_memory[ARG_DST] = getDstMemoryAtPort(0);

useExecutor = selectedPd->getImplementationType() == impl_desc_type::undef && !canOptimizeNspc;
m_executor.reset();

if (useExecutor) {
MemoryDescArgs descs;
descs.reserve(m_memory.size());
for (const auto& [arg, mem] : m_memory) {
descs[arg] = mem->getDescPtr();
}

try {
auto executionContext = std::make_shared<ExecutorContext>(context, getImplPriority());
auto factory = std::make_shared<ExecutorFactory<ConcatAttrs>>(m_attrs,
executionContext,
descs,
memoryFormatFilter);
m_executor = factory->make(m_memory);
selectedPd->setImplementationType(m_executor->implType());
} catch (...) {
useExecutor = false;
m_executor.reset();
}
}
}

Node::createPrimitive();
}

void Concat::initOptimalPrimitiveDescriptor() {
auto* selected_pd = getSelectedPrimitiveDescriptor();
CPU_NODE_ASSERT(selected_pd, "Preferable primitive descriptor is not set.");
Expand All @@ -525,6 +640,8 @@ void Concat::initOptimalPrimitiveDescriptor() {
}
}

useExecutor = selected_pd->getImplementationType() == impl_desc_type::undef;

// block layout may have axis greater than rank, disable ref_concat
auto* primDesc = getSelectedPrimitiveDescriptor();
auto* memDesc = primDesc->getConfig().outConfs[0].getMemDesc()->as<BlockedMemoryDesc>();
Expand Down Expand Up @@ -558,6 +675,11 @@ void Concat::execute(const dnnl::stream& strm) {
return;
}

if (useExecutor && m_executor) {
m_executor->execute(m_memory);
return;
}

if (canOptimize1DCase) {
exec1DCase();
return;
Expand Down
8 changes: 8 additions & 0 deletions src/plugins/intel_cpu/src/nodes/concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#include "edge.h"
#include "graph_context.h"
#include "node.h"
#include "nodes/executors/concat.hpp"
#include "nodes/executors/executor.hpp"
#include "nodes/executors/memory_arguments.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"

Expand Down Expand Up @@ -41,6 +44,7 @@ class Concat : public Node {
[[nodiscard]] bool isExecutable() const override;
[[nodiscard]] bool needPrepareParams() const override;
void prepareParams() override;
void createPrimitive() override;
// TODO: Move to base Node class when more nodes support fuse convert
bool supportConvertFusion() const {
return supportFuseConvert;
Expand Down Expand Up @@ -70,6 +74,10 @@ class Concat : public Node {
bool doFuseConvert = false; // whether to perform FP16 to FP32 conversion
static constexpr size_t MAX_RANK_REF = 6;
dnnl::primitive prim;
bool useExecutor = false;
ConcatAttrs m_attrs;
MemoryArgs m_memory;
ExecutorPtr m_executor = nullptr;
};

} // namespace ov::intel_cpu::node
Loading
Loading