Skip to content

Commit 2fabc60

Browse files
authored
Refactor reshape function to avoid u4/i4 type has memory issue (#34478)
### Details: - The `shape_size(in_shape) * elem_size` will double with real memory size when element is `u4/i4`, this will make the copy access the memory out of allocated. ### Tickets: - [CVS-179197](https://jira.devtools.intel.com/browse/CVS-179197) ### AI Assistance: - AI assistance used: yes - Construct a testcase instead of execute webNN fullstack to duplicate the issue. - Give the AI the callstack and it give me some suggestion
1 parent 5ed7d41 commit 2fabc60

File tree

4 files changed

+40
-5
lines changed

4 files changed

+40
-5
lines changed

src/core/reference/include/openvino/reference/reshape.hpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,28 @@ namespace reference {
1212

1313
/**
1414
* @brief Basic reshape operation, without axes reorder.
15+
* This only for byte like types(elem size is N*8)
1516
*
1617
* @param in Pointer to input data.
1718
* @param out Pointer to output data.
1819
* @param in_shape Input data shape.
19-
* @param out_shape Output data shape.
20-
* @param elem_size Single data element size im bytes.
20+
* @param elem_size Single data element size in bytes.
2121
*/
2222
inline void reshape(const char* in, char* out, const Shape& in_shape, size_t elem_size) {
2323
std::memcpy(out, in, shape_size(in_shape) * elem_size);
2424
}
2525

26+
/**
27+
* @brief Basic reshape operation copy with real size, without axes reorder.
28+
*
29+
* @param in Pointer to input data.
30+
* @param out Pointer to output data.
31+
* @param copy_size Number of bytes to copy.
32+
*/
33+
inline void reshape(const char* in, char* out, size_t copy_size) {
34+
std::memcpy(out, in, copy_size);
35+
}
36+
2637
/**
2738
* @brief Basic reshape operation for string element type.
2839
* Input data are simply copied to the output.

src/core/src/op/reshape.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ bool Reshape::evaluate_reshape(TensorVector& outputs, const TensorVector& inputs
6767
} else {
6868
ov::reference::reshape(static_cast<const char*>(inputs[0].data()),
6969
static_cast<char*>(outputs[0].data()),
70-
inputs[0].get_shape(),
71-
inputs[0].get_element_type().size());
70+
inputs[0].get_byte_size());
7271
}
7372
return true;
7473
}

src/plugins/template/tests/functional/op_reference/base_reference_test.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ template <class T>
4848
ov::Tensor CreateTensor(const ov::element::Type& element_type, const std::vector<T>& values, size_t size = 0) {
4949
size_t real_size = size ? size : values.size() * sizeof(T) / element_type.size();
5050
ov::Tensor tensor{element_type, {real_size}};
51-
std::memcpy(tensor.data(), values.data(), std::min(real_size * element_type.size(), sizeof(T) * values.size()));
51+
std::memcpy(tensor.data(), values.data(), std::min(tensor.get_byte_size(), sizeof(T) * values.size()));
5252

5353
return tensor;
5454
}

src/plugins/template/tests/functional/op_reference/reshape.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,29 @@ std::vector<ReshapeParams> generateParamsForReshape8Bit() {
299299
return params;
300300
}
301301

302+
template <element::Type_t ET>
303+
std::vector<ReshapeParams> generateParamsForReshape4Bit() {
304+
using T = typename element_type_traits<ET>::value_type;
305+
306+
std::vector<ReshapeParams> params{
307+
ReshapeParams(Shape{2, 2, 3},
308+
Shape{12},
309+
ET,
310+
ET,
311+
std::vector<T>{0x12, 0x34, 0x56, 0x78, 0x1A, 0x2C},
312+
std::vector<T>{0x12, 0x34, 0x56, 0x78, 0x1A, 0x2C},
313+
false),
314+
ReshapeParams(Shape{1, 1, 1}, Shape{}, ET, ET, std::vector<T>{6}, std::vector<T>{6}, false),
315+
ReshapeParams(Shape{}, Shape{1, 1, 1, 1, 1, 1}, ET, ET, std::vector<T>{7}, std::vector<T>{7}, false),
316+
ReshapeParams(Shape{3}, Shape{3, 1}, ET, ET, std::vector<T>{0x12, 0x03}, std::vector<T>{0x12, 0x03}, false),
317+
ReshapeParams(Shape{3}, Shape{1, 3}, ET, ET, std::vector<T>{0x12, 0x03}, std::vector<T>{0x12, 0x03}, false),
318+
ReshapeParams(Shape{3}, Shape{1, 3, 1}, ET, ET, std::vector<T>{0x12, 0x03}, std::vector<T>{0x12, 0x03}, false),
319+
ReshapeParams(Shape{1}, Shape{}, ET, ET, std::vector<T>{1}, std::vector<T>{1}, false),
320+
ReshapeParams(Shape{}, Shape{}, ET, ET, std::vector<T>{1}, std::vector<T>{1}, false)};
321+
322+
return params;
323+
}
324+
302325
template <element::Type_t ET>
303326
std::vector<ReshapeShuffleParams> generateParamsForReshapeShuffle() {
304327
using T = typename element_type_traits<ET>::value_type;
@@ -325,6 +348,8 @@ std::vector<ReshapeParams> generateCombinedParamsForReshape() {
325348
generateParamsForReshape<element::Type_t::u16>(),
326349
generateParamsForReshape8Bit<element::Type_t::i8>(),
327350
generateParamsForReshape8Bit<element::Type_t::u8>(),
351+
generateParamsForReshape4Bit<element::Type_t::i4>(),
352+
generateParamsForReshape4Bit<element::Type_t::u4>(),
328353
generateParamsForReshapeString()};
329354

330355
std::vector<ReshapeParams> combinedParams;

0 commit comments

Comments
 (0)