Skip to content

Commit 7cb3bf5

Browse files
authored
[CPU] Enable bf16 acdb layout for transpose (#21030)
1 parent bb3ed2d commit 7cb3bf5

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/plugins/intel_cpu/src/nodes/transpose.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ void Transpose::initSupportedPrimitiveDescriptors() {
111111
supportedPrimitiveDescriptorsBuilder(config, transposeParams);
112112
}
113113
#endif // OPENVINO_ARCH_X86_64
114-
if (prec == Precision::FP32 || prec == Precision::FP16 || prec == Precision::I8 || prec == Precision::U8) {
114+
if (prec == Precision::FP32 || prec == Precision::FP16 || prec == Precision::I8 || prec == Precision::U8 || prec == Precision::BF16) {
115115
config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::nspc)->createSharedDesc(prec, inputDataShape));
116116
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::nspc)->createSharedDesc(prec, outputDataShape));
117117
supportedPrimitiveDescriptorsBuilder(config, transposeParams);

src/plugins/intel_cpu/tests/functional/single_layer_tests/instances/x64/transpose.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ const auto cpuParams_nCdhw16c = CPUSpecificParams {{nCdhw16c}, {}, {}, {}};
2525

2626
const auto cpuParams_nChw8c = CPUSpecificParams {{nChw8c}, {}, {}, {}};
2727
const auto cpuParams_nCdhw8c = CPUSpecificParams {{nCdhw8c}, {}, {}, {}};
28+
const auto cpuParams_nspc = CPUSpecificParams {{acdb}, {}, {}, {}};
2829

2930
const std::vector<InferenceEngine::Precision> netPrecisions = {
3031
Precision::I8,
@@ -64,7 +65,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamicShapes4D_Transpose, TransposeLayerCPUTest,
6465
::testing::Values(Precision::BF16),
6566
::testing::Values(ov::test::utils::DEVICE_CPU),
6667
::testing::Values(additional_config),
67-
::testing::Values(CPUSpecificParams{})),
68+
::testing::ValuesIn({CPUSpecificParams{}, cpuParams_nspc})),
6869
TransposeLayerCPUTest::getTestCaseName);
6970

7071
const std::vector<InputShape> staticInputShapes5DC16 = {InputShape{

0 commit comments

Comments
 (0)