Skip to content

Commit 3c49e56

Browse files
committed
[cker] Fix array-bounds build error: Refactor Shape.h to use C++17 std::variant
This commit improves type-safety and memory management by leveraging modern C++17 features. - Replace the old union-based storage (_dims and _dims_pointer) with a std::variant that holds either a std::array (for shapes with ≤ 6 dimensions) or a std::vector (for larger shapes) - Update constructors to initialize and resize dims_ accordingly, ensuring that small shapes use the fixed-size array and larger ones use dynamic allocation - Implement a custom copy constructor to perform a deep copy based on the current storage type and default the move constructor - Update various member functions (Dims, SetDim, DimsData, Resize, ReplaceWith, BuildFrom, etc.) to use the new std::variant-based storage - Minor include reordering and addition of necessary headers (e.g., <array>, <iterator>, <variant>) ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent d172cc3 commit 3c49e56

2 files changed

Lines changed: 106 additions & 75 deletions

File tree

runtime/compute/cker/include/cker/Shape.h

Lines changed: 105 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
#define __NNFW_CKER_SHAPE_H__
2020

2121
#include <algorithm>
22-
#include <cstring>
22+
#include <array>
2323
#include <cassert>
24+
#include <cstring>
25+
#include <iterator>
26+
#include <variant>
2427
#include <vector>
2528

2629
namespace nnfw
@@ -35,18 +38,25 @@ class Shape
3538
// larger shapes are separately allocated.
3639
static constexpr int kMaxSmallSize = 6;
3740

41+
// Delete copy assignment operator.
3842
Shape &operator=(Shape const &) = delete;
3943

40-
Shape() : _size(0) {}
44+
// Default constructor: initializes an empty shape (size = 0) with small storage.
45+
Shape() : _size(0), dims_(std::array<int32_t, kMaxSmallSize>{}) {}
4146

47+
// Constructor that takes a dimension count.
48+
// If dimensions_count <= kMaxSmallSize, it uses a fixed-size array.
49+
// Otherwise, it uses a dynamic vector.
4250
explicit Shape(int dimensions_count) : _size(dimensions_count)
4351
{
44-
if (dimensions_count > kMaxSmallSize)
45-
{
46-
_dims_pointer = new int32_t[dimensions_count];
52+
if (dimensions_count <= kMaxSmallSize) {
53+
dims_ = std::array<int32_t, kMaxSmallSize>{};
54+
} else {
55+
dims_ = std::vector<int32_t>(dimensions_count);
4756
}
4857
}
4958

59+
// Constructor that creates a shape of given size and fills all dimensions with "value".
5060
Shape(int shape_size, int32_t value) : _size(0)
5161
{
5262
Resize(shape_size);
@@ -56,136 +66,155 @@ class Shape
5666
}
5767
}
5868

69+
// Constructor that creates a shape from an array of dimension data.
5970
Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
6071
{
6172
ReplaceWith(dimensions_count, dims_data);
6273
}
6374

64-
Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }
75+
// Initializer list constructor.
76+
// Marked explicit to avoid unintended overload resolution.
77+
Shape(const std::initializer_list<int> init_list) : _size(0)
78+
{
79+
BuildFrom(init_list);
80+
}
6581

66-
// Avoid using this constructor. We should be able to delete it when C++17
67-
// rolls out.
68-
Shape(Shape const &other) : _size(other.DimensionsCount())
82+
// Copy constructor
83+
Shape(const Shape &other) : _size(other._size)
6984
{
70-
if (_size > kMaxSmallSize)
71-
{
72-
_dims_pointer = new int32_t[_size];
85+
if (_size <= kMaxSmallSize) {
86+
// When the number of dimensions is small, copy the fixed array.
87+
dims_ = std::get<std::array<int32_t, kMaxSmallSize>>(other.dims_);
88+
} else {
89+
// Otherwise, copy the dynamically allocated vector.
90+
dims_ = std::get<std::vector<int32_t>>(other.dims_);
7391
}
74-
std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
7592
}
93+
Shape(Shape &&other) = default;
7694

7795
bool operator==(const Shape &comp) const
7896
{
7997
return this->_size == comp._size &&
8098
std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
8199
}
82100

83-
~Shape()
84-
{
85-
if (_size > kMaxSmallSize)
86-
{
87-
delete[] _dims_pointer;
88-
}
89-
}
101+
~Shape() = default;
90102

103+
// Returns the number of dimensions.
91104
inline int32_t DimensionsCount() const { return _size; }
105+
106+
// Returns the dimension size at index i.
92107
inline int32_t Dims(int i) const
93108
{
94-
assert(i >= 0);
95-
assert(i < _size);
96-
return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
109+
assert(i >= 0 && i < _size);
110+
if (_size <= kMaxSmallSize) {
111+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i];
112+
} else {
113+
return std::get<std::vector<int32_t>>(dims_)[i];
114+
}
97115
}
116+
117+
// Sets the dimension at index i.
98118
inline void SetDim(int i, int32_t val)
99119
{
100-
assert(i >= 0);
101-
assert(i < _size);
102-
if (_size > kMaxSmallSize)
103-
{
104-
_dims_pointer[i] = val;
120+
assert(i >= 0 && i < _size);
121+
if (_size <= kMaxSmallSize) {
122+
std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i] = val;
123+
} else {
124+
std::get<std::vector<int32_t>>(dims_)[i] = val;
105125
}
106-
else
107-
{
108-
_dims[i] = val;
126+
}
127+
128+
// Returns a pointer to the dimension data (mutable).
129+
inline int32_t* DimsData()
130+
{
131+
if (_size <= kMaxSmallSize) {
132+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
133+
} else {
134+
return std::get<std::vector<int32_t>>(dims_).data();
109135
}
110136
}
111137

112-
inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
113-
inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
114-
// The caller must ensure that the shape is no bigger than 6D.
115-
inline const int32_t *DimsDataUpTo6D() const { return _dims; }
138+
// Returns a pointer to the dimension data (const).
139+
inline const int32_t* DimsData() const
140+
{
141+
if (_size <= kMaxSmallSize) {
142+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
143+
} else {
144+
return std::get<std::vector<int32_t>>(dims_).data();
145+
}
146+
}
116147

148+
// The caller must ensure that the shape is no larger than 6D.
149+
inline const int32_t* DimsDataUpTo6D() const {
150+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
151+
}
152+
153+
// Resizes the shape to dimensions_count.
117154
inline void Resize(int dimensions_count)
118155
{
119-
if (_size > kMaxSmallSize)
120-
{
121-
delete[] _dims_pointer;
122-
}
123156
_size = dimensions_count;
124-
if (dimensions_count > kMaxSmallSize)
125-
{
126-
_dims_pointer = new int32_t[dimensions_count];
157+
if (dimensions_count <= kMaxSmallSize) {
158+
dims_ = std::array<int32_t, kMaxSmallSize>{};
159+
} else {
160+
dims_ = std::vector<int32_t>(dimensions_count);
127161
}
128162
}
129163

164+
// Replaces the current shape with a new one defined by dimensions_count and dims_data.
130165
inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
131166
{
132167
Resize(dimensions_count);
133-
int32_t *dst_dims = DimsData();
134-
std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
168+
std::memcpy(DimsData(), dims_data, dimensions_count * sizeof(int32_t));
135169
}
136170

171+
// Replaces the current shape with another shape.
137172
inline void ReplaceWith(const Shape &other)
138173
{
139174
ReplaceWith(other.DimensionsCount(), other.DimsData());
140175
}
141176

177+
// Replaces the current shape with another shape using move semantics.
142178
inline void ReplaceWith(Shape &&other)
143179
{
144-
Resize(0);
145180
std::swap(_size, other._size);
146-
if (_size <= kMaxSmallSize)
147-
std::copy(other._dims, other._dims + kMaxSmallSize, _dims);
148-
else
149-
_dims_pointer = other._dims_pointer;
181+
dims_ = std::move(other.dims_);
150182
}
151183

152-
template <typename T> inline void BuildFrom(const T &src_iterable)
184+
// Builds the shape from an iterable sequence.
185+
template <typename Iterable>
186+
inline void BuildFrom(const Iterable &src_iterable)
153187
{
154-
const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
188+
const int dimensions_count = static_cast<int>(std::distance(src_iterable.begin(), src_iterable.end()));
155189
Resize(dimensions_count);
156-
int32_t *data = DimsData();
157-
for (auto &&it : src_iterable)
190+
int32_t* data = DimsData();
191+
for (auto it = src_iterable.begin(); it != src_iterable.end(); ++it)
158192
{
159-
*data = it;
160-
++data;
193+
*data++ = static_cast<int32_t>(*it);
161194
}
162195
}
163196

164-
// This will probably be factored out. Old code made substantial use of 4-D
165-
// shapes, and so this function is used to extend smaller shapes. Note that
166-
// (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
167-
// reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
168-
// inputs should already be 4-D, so this function should not be needed.
197+
// Returns the total count of elements, that is the size when flattened into a
198+
// vector.
169199
inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
170200
{
171201
return Shape(new_shape_size, shape, 1);
172202
}
173203

204+
// Overload for initializer list building.
174205
inline void BuildFrom(const std::initializer_list<int> init_list)
175206
{
176207
BuildFrom<const std::initializer_list<int>>(init_list);
177208
}
178209

179-
// Returns the total count of elements, that is the size when flattened into a
180-
// vector.
210+
// Returns the total count of elements (flattened size).
181211
inline int FlatSize() const
182212
{
183213
int buffer_size = 1;
184-
const int *dims_data = DimsData();
214+
const int* dims_data = DimsData();
185215
for (int i = 0; i < _size; i++)
186216
{
187-
const int dim = dims_data[i];
188-
buffer_size *= dim;
217+
buffer_size *= dims_data[i];
189218
}
190219
return buffer_size;
191220
}
@@ -206,17 +235,17 @@ class Shape
206235
{
207236
SetDim(i, pad_value);
208237
}
209-
std::memcpy(DimsData() + size_increase, shape.DimsData(),
210-
sizeof(int32_t) * shape.DimensionsCount());
238+
std::memcpy(DimsData() + size_increase, shape.DimsData(), sizeof(int32_t) * shape.DimensionsCount());
211239
}
212240

213241
int32_t _size;
214-
union {
215-
int32_t _dims[kMaxSmallSize];
216-
int32_t *_dims_pointer{nullptr};
217-
};
242+
// Internal storage: use std::array for shapes with dimensions up to kMaxSmallSize,
243+
// and std::vector for larger shapes.
244+
std::variant<std::array<int32_t, kMaxSmallSize>, std::vector<int32_t>> dims_;
218245
};
219246

247+
// Utility functions below.
248+
220249
inline int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
221250
[[maybe_unused]] int index2)
222251
{
@@ -232,7 +261,10 @@ int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &s
232261
return MatchingDim(shape1, index1, args...);
233262
}
234263

235-
inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
264+
inline Shape GetShape(const std::vector<int32_t> &data)
265+
{
266+
return Shape(static_cast<int>(data.size()), data.data());
267+
}
236268

237269
inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
238270
{
@@ -278,8 +310,7 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
278310
return flat_size;
279311
}
280312

281-
// Flat size calculation, checking that dimensions match with one or more other
282-
// arrays.
313+
// Flat size calculation, checking that dimensions match with one or more other shapes.
283314
template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
284315
{
285316
auto match = [&shape](const Shape &s) -> bool {

runtime/compute/cker/include/cker/operation/Einsum.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ class Einsum
871871
Tensor rhs;
872872
reshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs);
873873
Shape old_output_shape = bcast.output_batch_shape();
874-
Shape output_shape(old_output_shape.DimensionsCount() + inputs.size());
874+
Shape output_shape(static_cast<int>(old_output_shape.DimensionsCount() + inputs.size()));
875875
for (int i = 0; i < old_output_shape.DimensionsCount(); i++)
876876
{
877877
output_shape.SetDim(i, old_output_shape.Dims(i));

0 commit comments

Comments
 (0)