1919#define __NNFW_CKER_SHAPE_H__
2020
2121#include < algorithm>
22- #include < cstring >
22+ #include < array >
2323#include < cassert>
24+ #include < cstring>
25+ #include < iterator>
26+ #include < variant>
2427#include < vector>
2528
2629namespace nnfw
@@ -35,18 +38,25 @@ class Shape
3538 // larger shapes are separately allocated.
3639 static constexpr int kMaxSmallSize = 6 ;
3740
41+ // Delete copy assignment operator.
3842 Shape &operator =(Shape const &) = delete ;
3943
40- Shape () : _size(0 ) {}
44+ // Default constructor: initializes an empty shape (size = 0) with small storage.
45+ Shape () : _size(0 ), dims_(std::array<int32_t , kMaxSmallSize >{}) {}
4146
47+ // Constructor that takes a dimension count.
48+ // If dimensions_count <= kMaxSmallSize, it uses a fixed-size array.
49+ // Otherwise, it uses a dynamic vector.
4250 explicit Shape (int dimensions_count) : _size(dimensions_count)
4351 {
44- if (dimensions_count > kMaxSmallSize )
45- {
46- _dims_pointer = new int32_t [dimensions_count];
52+ if (dimensions_count <= kMaxSmallSize ) {
53+ dims_ = std::array<int32_t , kMaxSmallSize >{};
54+ } else {
55+ dims_ = std::vector<int32_t >(dimensions_count);
4756 }
4857 }
4958
59+ // Constructor that creates a shape of given size and fills all dimensions with "value".
5060 Shape (int shape_size, int32_t value) : _size(0 )
5161 {
5262 Resize (shape_size);
@@ -56,136 +66,155 @@ class Shape
5666 }
5767 }
5868
69+ // Constructor that creates a shape from an array of dimension data.
5970 Shape (int dimensions_count, const int32_t *dims_data) : _size(0 )
6071 {
6172 ReplaceWith (dimensions_count, dims_data);
6273 }
6374
64- Shape (const std::initializer_list<int > init_list) : _size(0 ) { BuildFrom (init_list); }
75+ // Initializer list constructor.
76+ // Marked explicit to avoid unintended overload resolution.
77+ Shape (const std::initializer_list<int > init_list) : _size(0 )
78+ {
79+ BuildFrom (init_list);
80+ }
6581
66- // Avoid using this constructor. We should be able to delete it when C++17
67- // rolls out.
68- Shape (Shape const &other) : _size(other.DimensionsCount())
82+ // Copy constructor
83+ Shape (const Shape &other) : _size(other._size)
6984 {
70- if (_size > kMaxSmallSize )
71- {
72- _dims_pointer = new int32_t [_size];
85+ if (_size <= kMaxSmallSize ) {
86+ // When the number of dimensions is small, copy the fixed array.
87+ dims_ = std::get<std::array<int32_t , kMaxSmallSize >>(other.dims_ );
88+ } else {
89+ // Otherwise, copy the dynamically allocated vector.
90+ dims_ = std::get<std::vector<int32_t >>(other.dims_ );
7391 }
74- std::memcpy (DimsData (), other.DimsData (), sizeof (int32_t ) * _size);
7592 }
93+ Shape (Shape &&other) = default ;
7694
7795 bool operator ==(const Shape &comp) const
7896 {
7997 return this ->_size == comp._size &&
8098 std::memcmp (DimsData (), comp.DimsData (), _size * sizeof (int32_t )) == 0 ;
8199 }
82100
83- ~Shape ()
84- {
85- if (_size > kMaxSmallSize )
86- {
87- delete[] _dims_pointer;
88- }
89- }
101+ ~Shape () = default ;
90102
103+ // Returns the number of dimensions.
91104 inline int32_t DimensionsCount () const { return _size; }
105+
106+ // Returns the dimension size at index i.
92107 inline int32_t Dims (int i) const
93108 {
94- assert (i >= 0 );
95- assert (i < _size);
96- return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
109+ assert (i >= 0 && i < _size);
110+ if (_size <= kMaxSmallSize ) {
111+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_)[i];
112+ } else {
113+ return std::get<std::vector<int32_t >>(dims_)[i];
114+ }
97115 }
116+
117+ // Sets the dimension at index i.
98118 inline void SetDim (int i, int32_t val)
99119 {
100- assert (i >= 0 );
101- assert (i < _size);
102- if (_size > kMaxSmallSize )
103- {
104- _dims_pointer [i] = val;
120+ assert (i >= 0 && i < _size );
121+ if (_size <= kMaxSmallSize ) {
122+ std::get<std::array< int32_t , kMaxSmallSize >>(dims_)[i] = val;
123+ } else {
124+ std::get<std::vector< int32_t >>(dims_) [i] = val;
105125 }
106- else
107- {
108- _dims[i] = val;
126+ }
127+
128+ // Returns a pointer to the dimension data (mutable).
129+ inline int32_t * DimsData ()
130+ {
131+ if (_size <= kMaxSmallSize ) {
132+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_).data ();
133+ } else {
134+ return std::get<std::vector<int32_t >>(dims_).data ();
109135 }
110136 }
111137
112- inline int32_t *DimsData () { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
113- inline const int32_t *DimsData () const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
114- // The caller must ensure that the shape is no bigger than 6D.
115- inline const int32_t *DimsDataUpTo6D () const { return _dims; }
138+ // Returns a pointer to the dimension data (const).
139+ inline const int32_t * DimsData () const
140+ {
141+ if (_size <= kMaxSmallSize ) {
142+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_).data ();
143+ } else {
144+ return std::get<std::vector<int32_t >>(dims_).data ();
145+ }
146+ }
116147
148+ // The caller must ensure that the shape is no larger than 6D.
149+ inline const int32_t * DimsDataUpTo6D () const {
150+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_).data ();
151+ }
152+
153+ // Resizes the shape to dimensions_count.
117154 inline void Resize (int dimensions_count)
118155 {
119- if (_size > kMaxSmallSize )
120- {
121- delete[] _dims_pointer;
122- }
123156 _size = dimensions_count;
124- if (dimensions_count > kMaxSmallSize )
125- {
126- _dims_pointer = new int32_t [dimensions_count];
157+ if (dimensions_count <= kMaxSmallSize ) {
158+ dims_ = std::array<int32_t , kMaxSmallSize >{};
159+ } else {
160+ dims_ = std::vector<int32_t >(dimensions_count);
127161 }
128162 }
129163
164+ // Replaces the current shape with a new one defined by dimensions_count and dims_data.
130165 inline void ReplaceWith (int dimensions_count, const int32_t *dims_data)
131166 {
132167 Resize (dimensions_count);
133- int32_t *dst_dims = DimsData ();
134- std::memcpy (dst_dims, dims_data, dimensions_count * sizeof (int32_t ));
168+ std::memcpy (DimsData (), dims_data, dimensions_count * sizeof (int32_t ));
135169 }
136170
171+ // Replaces the current shape with another shape.
137172 inline void ReplaceWith (const Shape &other)
138173 {
139174 ReplaceWith (other.DimensionsCount (), other.DimsData ());
140175 }
141176
177+ // Replaces the current shape with another shape using move semantics.
142178 inline void ReplaceWith (Shape &&other)
143179 {
144- Resize (0 );
145180 std::swap (_size, other._size );
146- if (_size <= kMaxSmallSize )
147- std::copy (other._dims , other._dims + kMaxSmallSize , _dims);
148- else
149- _dims_pointer = other._dims_pointer ;
181+ dims_ = std::move (other.dims_ );
150182 }
151183
152- template <typename T> inline void BuildFrom (const T &src_iterable)
184+ // Builds the shape from an iterable sequence.
185+ template <typename Iterable>
186+ inline void BuildFrom (const Iterable &src_iterable)
153187 {
154- const int dimensions_count = std::distance (src_iterable.begin (), src_iterable.end ());
188+ const int dimensions_count = static_cast < int >( std::distance (src_iterable.begin (), src_iterable.end () ));
155189 Resize (dimensions_count);
156- int32_t * data = DimsData ();
157- for (auto && it : src_iterable)
190+ int32_t * data = DimsData ();
191+ for (auto it = src_iterable. begin (); it != src_iterable. end (); ++it )
158192 {
159- *data = it;
160- ++data;
193+ *data++ = static_cast <int32_t >(*it);
161194 }
162195 }
163196
164- // This will probably be factored out. Old code made substantial use of 4-D
165- // shapes, and so this function is used to extend smaller shapes. Note that
166- // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
167- // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
168- // inputs should already be 4-D, so this function should not be needed.
197+ // Returns the total count of elements, that is the size when flattened into a
198+ // vector.
169199 inline static Shape ExtendedShape (int new_shape_size, const Shape &shape)
170200 {
171201 return Shape (new_shape_size, shape, 1 );
172202 }
173203
204+ // Overload for initializer list building.
174205 inline void BuildFrom (const std::initializer_list<int > init_list)
175206 {
176207 BuildFrom<const std::initializer_list<int >>(init_list);
177208 }
178209
179- // Returns the total count of elements, that is the size when flattened into a
180- // vector.
210+ // Returns the total count of elements (flattened size).
181211 inline int FlatSize () const
182212 {
183213 int buffer_size = 1 ;
184- const int * dims_data = DimsData ();
214+ const int * dims_data = DimsData ();
185215 for (int i = 0 ; i < _size; i++)
186216 {
187- const int dim = dims_data[i];
188- buffer_size *= dim;
217+ buffer_size *= dims_data[i];
189218 }
190219 return buffer_size;
191220 }
@@ -206,17 +235,17 @@ class Shape
206235 {
207236 SetDim (i, pad_value);
208237 }
209- std::memcpy (DimsData () + size_increase, shape.DimsData (),
210- sizeof (int32_t ) * shape.DimensionsCount ());
238+ std::memcpy (DimsData () + size_increase, shape.DimsData (), sizeof (int32_t ) * shape.DimensionsCount ());
211239 }
212240
213241 int32_t _size;
214- union {
215- int32_t _dims[kMaxSmallSize ];
216- int32_t *_dims_pointer{nullptr };
217- };
242+ // Internal storage: use std::array for shapes with dimensions up to kMaxSmallSize,
243+ // and std::vector for larger shapes.
244+ std::variant<std::array<int32_t , kMaxSmallSize >, std::vector<int32_t >> dims_;
218245};
219246
247+ // Utility functions below.
248+
220249inline int MatchingDim (const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
221250 [[maybe_unused]] int index2)
222251{
@@ -232,7 +261,10 @@ int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &s
232261 return MatchingDim (shape1, index1, args...);
233262}
234263
235- inline Shape GetShape (const std::vector<int32_t > &data) { return Shape (data.size (), data.data ()); }
264+ inline Shape GetShape (const std::vector<int32_t > &data)
265+ {
266+ return Shape (static_cast <int >(data.size ()), data.data ());
267+ }
236268
237269inline int Offset (const Shape &shape, int i0, int i1, int i2, int i3)
238270{
@@ -278,8 +310,7 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
278310 return flat_size;
279311}
280312
281- // Flat size calculation, checking that dimensions match with one or more other
282- // arrays.
313+ // Flat size calculation, checking that dimensions match with one or more other shapes.
283314template <typename ... Ts> inline bool checkMatching (const Shape &shape, Ts... check_shapes)
284315{
285316 auto match = [&shape](const Shape &s) -> bool {
0 commit comments