Skip to content

Commit 0fcfa6a

Browse files
committed
[cker] Fix array-bounds build error: Refactor Shape.h to use C++17 std::variant
This commit improves type-safety and memory management by leveraging modern C++17 features. - Replace the old union-based storage (_dims and _dims_pointer) with a std::variant that holds either a std::array (for shapes with ≤ 6 dimensions) or a std::vector (for larger shapes) - Update constructors to initialize and resize dims_ accordingly, ensuring that small shapes use the fixed-size array and larger ones use dynamic allocation - Implement a custom copy constructor to perform a deep copy based on the current storage type and default the move constructor - Update various member functions (Dims, SetDim, DimsData, Resize, ReplaceWith, BuildFrom, etc.) to use the new std::variant-based storage - Minor include reordering and addition of necessary headers (e.g., <array>, <iterator>, <variant>) ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent d172cc3 commit 0fcfa6a

2 files changed

Lines changed: 113 additions & 62 deletions

File tree

runtime/compute/cker/include/cker/Shape.h

Lines changed: 112 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
#define __NNFW_CKER_SHAPE_H__
2020

2121
#include <algorithm>
22-
#include <cstring>
22+
#include <array>
2323
#include <cassert>
24+
#include <cstring>
25+
#include <iterator>
26+
#include <variant>
2427
#include <vector>
2528

2629
namespace nnfw
@@ -35,18 +38,28 @@ class Shape
3538
// larger shapes are separately allocated.
3639
static constexpr int kMaxSmallSize = 6;
3740

41+
// Delete copy assignment operator.
3842
Shape &operator=(Shape const &) = delete;
3943

40-
Shape() : _size(0) {}
44+
// Default constructor: initializes an empty shape (size = 0) with small storage.
45+
Shape() : _size(0), dims_(std::array<int32_t, kMaxSmallSize>{}) {}
4146

47+
// Constructor that takes a dimension count.
48+
// If dimensions_count <= kMaxSmallSize, it uses a fixed-size array.
49+
// Otherwise, it uses a dynamic vector.
4250
explicit Shape(int dimensions_count) : _size(dimensions_count)
4351
{
44-
if (dimensions_count > kMaxSmallSize)
52+
if (dimensions_count <= kMaxSmallSize)
4553
{
46-
_dims_pointer = new int32_t[dimensions_count];
54+
dims_ = std::array<int32_t, kMaxSmallSize>{};
55+
}
56+
else
57+
{
58+
dims_ = std::vector<int32_t>(dimensions_count);
4759
}
4860
}
4961

62+
// Constructor that creates a shape of given size and fills all dimensions with "value".
5063
Shape(int shape_size, int32_t value) : _size(0)
5164
{
5265
Resize(shape_size);
@@ -56,136 +69,171 @@ class Shape
5669
}
5770
}
5871

72+
// Constructor that creates a shape from an array of dimension data.
5973
Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
6074
{
6175
ReplaceWith(dimensions_count, dims_data);
6276
}
6377

78+
// Initializer list constructor.
79+
// Marked explicit to avoid unintended overload resolution.
6480
Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }
6581

66-
// Avoid using this constructor. We should be able to delete it when C++17
67-
// rolls out.
68-
Shape(Shape const &other) : _size(other.DimensionsCount())
82+
// Copy constructor
83+
Shape(const Shape &other) : _size(other._size)
6984
{
70-
if (_size > kMaxSmallSize)
85+
if (_size <= kMaxSmallSize)
86+
{
87+
// When the number of dimensions is small, copy the fixed array.
88+
dims_ = std::get<std::array<int32_t, kMaxSmallSize>>(other.dims_);
89+
}
90+
else
7191
{
72-
_dims_pointer = new int32_t[_size];
92+
// Otherwise, copy the dynamically allocated vector.
93+
dims_ = std::get<std::vector<int32_t>>(other.dims_);
7394
}
74-
std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
7595
}
96+
Shape(Shape &&other) = default;
7697

7798
bool operator==(const Shape &comp) const
7899
{
79100
return this->_size == comp._size &&
80101
std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
81102
}
82103

83-
~Shape()
104+
~Shape() = default;
105+
106+
// Returns the number of dimensions.
107+
inline int32_t DimensionsCount() const { return _size; }
108+
109+
// Returns the dimension size at index i.
110+
inline int32_t Dims(int i) const
84111
{
85-
if (_size > kMaxSmallSize)
112+
assert(i >= 0 && i < _size);
113+
if (_size <= kMaxSmallSize)
86114
{
87-
delete[] _dims_pointer;
115+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i];
116+
}
117+
else
118+
{
119+
return std::get<std::vector<int32_t>>(dims_)[i];
88120
}
89121
}
90122

91-
inline int32_t DimensionsCount() const { return _size; }
92-
inline int32_t Dims(int i) const
123+
// Sets the dimension at index i.
124+
inline void SetDim(int i, int32_t val)
93125
{
94-
assert(i >= 0);
95-
assert(i < _size);
96-
return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
126+
assert(i >= 0 && i < _size);
127+
if (_size <= kMaxSmallSize)
128+
{
129+
std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i] = val;
130+
}
131+
else
132+
{
133+
std::get<std::vector<int32_t>>(dims_)[i] = val;
134+
}
97135
}
98-
inline void SetDim(int i, int32_t val)
136+
137+
// Returns a pointer to the dimension data (mutable).
138+
inline int32_t *DimsData()
99139
{
100-
assert(i >= 0);
101-
assert(i < _size);
102-
if (_size > kMaxSmallSize)
140+
if (_size <= kMaxSmallSize)
103141
{
104-
_dims_pointer[i] = val;
142+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
105143
}
106144
else
107145
{
108-
_dims[i] = val;
146+
return std::get<std::vector<int32_t>>(dims_).data();
109147
}
110148
}
111149

112-
inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
113-
inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
114-
// The caller must ensure that the shape is no bigger than 6D.
115-
inline const int32_t *DimsDataUpTo6D() const { return _dims; }
150+
// Returns a pointer to the dimension data (const).
151+
inline const int32_t *DimsData() const
152+
{
153+
if (_size <= kMaxSmallSize)
154+
{
155+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
156+
}
157+
else
158+
{
159+
return std::get<std::vector<int32_t>>(dims_).data();
160+
}
161+
}
162+
163+
// The caller must ensure that the shape is no larger than 6D.
164+
inline const int32_t *DimsDataUpTo6D() const
165+
{
166+
return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
167+
}
116168

169+
// Resizes the shape to dimensions_count.
117170
inline void Resize(int dimensions_count)
118171
{
119-
if (_size > kMaxSmallSize)
172+
_size = dimensions_count;
173+
if (dimensions_count <= kMaxSmallSize)
120174
{
121-
delete[] _dims_pointer;
175+
dims_ = std::array<int32_t, kMaxSmallSize>{};
122176
}
123-
_size = dimensions_count;
124-
if (dimensions_count > kMaxSmallSize)
177+
else
125178
{
126-
_dims_pointer = new int32_t[dimensions_count];
179+
dims_ = std::vector<int32_t>(dimensions_count);
127180
}
128181
}
129182

183+
// Replaces the current shape with a new one defined by dimensions_count and dims_data.
130184
inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
131185
{
132186
Resize(dimensions_count);
133-
int32_t *dst_dims = DimsData();
134-
std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32_t));
187+
std::memcpy(DimsData(), dims_data, dimensions_count * sizeof(int32_t));
135188
}
136189

190+
// Replaces the current shape with another shape.
137191
inline void ReplaceWith(const Shape &other)
138192
{
139193
ReplaceWith(other.DimensionsCount(), other.DimsData());
140194
}
141195

196+
// Replaces the current shape with another shape using move semantics.
142197
inline void ReplaceWith(Shape &&other)
143198
{
144-
Resize(0);
145199
std::swap(_size, other._size);
146-
if (_size <= kMaxSmallSize)
147-
std::copy(other._dims, other._dims + kMaxSmallSize, _dims);
148-
else
149-
_dims_pointer = other._dims_pointer;
200+
dims_ = std::move(other.dims_);
150201
}
151202

152-
template <typename T> inline void BuildFrom(const T &src_iterable)
203+
// Builds the shape from an iterable sequence.
204+
template <typename Iterable> inline void BuildFrom(const Iterable &src_iterable)
153205
{
154-
const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
206+
const int dimensions_count =
207+
static_cast<int>(std::distance(src_iterable.begin(), src_iterable.end()));
155208
Resize(dimensions_count);
156209
int32_t *data = DimsData();
157-
for (auto &&it : src_iterable)
210+
for (auto it = src_iterable.begin(); it != src_iterable.end(); ++it)
158211
{
159-
*data = it;
160-
++data;
212+
*data++ = static_cast<int32_t>(*it);
161213
}
162214
}
163215

164-
// This will probably be factored out. Old code made substantial use of 4-D
165-
// shapes, and so this function is used to extend smaller shapes. Note that
166-
// (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
167-
// reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
168-
// inputs should already be 4-D, so this function should not be needed.
216+
// Returns the total count of elements, that is the size when flattened into a
217+
// vector.
169218
inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
170219
{
171220
return Shape(new_shape_size, shape, 1);
172221
}
173222

223+
// Overload for initializer list building.
174224
inline void BuildFrom(const std::initializer_list<int> init_list)
175225
{
176226
BuildFrom<const std::initializer_list<int>>(init_list);
177227
}
178228

179-
// Returns the total count of elements, that is the size when flattened into a
180-
// vector.
229+
// Returns the total count of elements (flattened size).
181230
inline int FlatSize() const
182231
{
183232
int buffer_size = 1;
184233
const int *dims_data = DimsData();
185234
for (int i = 0; i < _size; i++)
186235
{
187-
const int dim = dims_data[i];
188-
buffer_size *= dim;
236+
buffer_size *= dims_data[i];
189237
}
190238
return buffer_size;
191239
}
@@ -211,12 +259,13 @@ class Shape
211259
}
212260

213261
int32_t _size;
214-
union {
215-
int32_t _dims[kMaxSmallSize];
216-
int32_t *_dims_pointer{nullptr};
217-
};
262+
// Internal storage: use std::array for shapes with dimensions up to kMaxSmallSize,
263+
// and std::vector for larger shapes.
264+
std::variant<std::array<int32_t, kMaxSmallSize>, std::vector<int32_t>> dims_;
218265
};
219266

267+
// Utility functions below.
268+
220269
inline int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
221270
[[maybe_unused]] int index2)
222271
{
@@ -232,7 +281,10 @@ int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &s
232281
return MatchingDim(shape1, index1, args...);
233282
}
234283

235-
inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
284+
inline Shape GetShape(const std::vector<int32_t> &data)
285+
{
286+
return Shape(static_cast<int>(data.size()), data.data());
287+
}
236288

237289
inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
238290
{
@@ -278,8 +330,7 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
278330
return flat_size;
279331
}
280332

281-
// Flat size calculation, checking that dimensions match with one or more other
282-
// arrays.
333+
// Flat size calculation, checking that dimensions match with one or more other shapes.
283334
template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
284335
{
285336
auto match = [&shape](const Shape &s) -> bool {

runtime/compute/cker/include/cker/operation/Einsum.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ class Einsum
871871
Tensor rhs;
872872
reshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs);
873873
Shape old_output_shape = bcast.output_batch_shape();
874-
Shape output_shape(old_output_shape.DimensionsCount() + inputs.size());
874+
Shape output_shape(static_cast<int>(old_output_shape.DimensionsCount() + inputs.size()));
875875
for (int i = 0; i < old_output_shape.DimensionsCount(); i++)
876876
{
877877
output_shape.SetDim(i, old_output_shape.Dims(i));

0 commit comments

Comments
 (0)