1919#define __NNFW_CKER_SHAPE_H__
2020
2121#include < algorithm>
22- #include < cstring >
22+ #include < array >
2323#include < cassert>
24+ #include < cstring>
25+ #include < iterator>
26+ #include < variant>
2427#include < vector>
2528
2629namespace nnfw
@@ -35,18 +38,28 @@ class Shape
3538 // larger shapes are separately allocated.
3639 static constexpr int kMaxSmallSize = 6 ;
3740
41+ // Delete copy assignment operator.
3842 Shape &operator =(Shape const &) = delete ;
3943
40- Shape () : _size(0 ) {}
44+ // Default constructor: initializes an empty shape (size = 0) with small storage.
45+ Shape () : _size(0 ), dims_(std::array<int32_t , kMaxSmallSize >{}) {}
4146
47+ // Constructor that takes a dimension count.
48+ // If dimensions_count <= kMaxSmallSize, it uses a fixed-size array.
49+ // Otherwise, it uses a dynamic vector.
4250 explicit Shape (int dimensions_count) : _size(dimensions_count)
4351 {
44- if (dimensions_count > kMaxSmallSize )
52+ if (dimensions_count <= kMaxSmallSize )
4553 {
46- _dims_pointer = new int32_t [dimensions_count];
54+ dims_ = std::array<int32_t , kMaxSmallSize >{};
55+ }
56+ else
57+ {
58+ dims_ = std::vector<int32_t >(dimensions_count);
4759 }
4860 }
4961
62+ // Constructor that creates a shape of given size and fills all dimensions with "value".
5063 Shape (int shape_size, int32_t value) : _size(0 )
5164 {
5265 Resize (shape_size);
@@ -56,136 +69,171 @@ class Shape
5669 }
5770 }
5871
72+ // Constructor that creates a shape from an array of dimension data.
5973 Shape (int dimensions_count, const int32_t *dims_data) : _size(0 )
6074 {
6175 ReplaceWith (dimensions_count, dims_data);
6276 }
6377
78+ // Initializer list constructor.
79+ // Marked explicit to avoid unintended overload resolution.
6480 Shape (const std::initializer_list<int > init_list) : _size(0 ) { BuildFrom (init_list); }
6581
66- // Avoid using this constructor. We should be able to delete it when C++17
67- // rolls out.
68- Shape (Shape const &other) : _size(other.DimensionsCount())
82+ // Copy constructor
83+ Shape (const Shape &other) : _size(other._size)
6984 {
70- if (_size > kMaxSmallSize )
85+ if (_size <= kMaxSmallSize )
86+ {
87+ // When the number of dimensions is small, copy the fixed array.
88+ dims_ = std::get<std::array<int32_t , kMaxSmallSize >>(other.dims_ );
89+ }
90+ else
7191 {
72- _dims_pointer = new int32_t [_size];
92+ // Otherwise, copy the dynamically allocated vector.
93+ dims_ = std::get<std::vector<int32_t >>(other.dims_ );
7394 }
74- std::memcpy (DimsData (), other.DimsData (), sizeof (int32_t ) * _size);
7595 }
96+ Shape (Shape &&other) = default ;
7697
7798 bool operator ==(const Shape &comp) const
7899 {
79100 return this ->_size == comp._size &&
80101 std::memcmp (DimsData (), comp.DimsData (), _size * sizeof (int32_t )) == 0 ;
81102 }
82103
83- ~Shape ()
104+ ~Shape () = default ;
105+
106+ // Returns the number of dimensions.
107+ inline int32_t DimensionsCount () const { return _size; }
108+
109+ // Returns the dimension size at index i.
110+ inline int32_t Dims (int i) const
84111 {
85- if (_size > kMaxSmallSize )
112+ assert (i >= 0 && i < _size);
113+ if (_size <= kMaxSmallSize )
86114 {
87- delete[] _dims_pointer;
115+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_)[i];
116+ }
117+ else
118+ {
119+ return std::get<std::vector<int32_t >>(dims_)[i];
88120 }
89121 }
90122
91- inline int32_t DimensionsCount () const { return _size; }
92- inline int32_t Dims (int i) const
123+ // Sets the dimension at index i.
124+ inline void SetDim (int i, int32_t val)
93125 {
94- assert (i >= 0 );
95- assert (i < _size);
96- return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
126+ assert (i >= 0 && i < _size);
127+ if (_size <= kMaxSmallSize )
128+ {
129+ std::get<std::array<int32_t , kMaxSmallSize >>(dims_)[i] = val;
130+ }
131+ else
132+ {
133+ std::get<std::vector<int32_t >>(dims_)[i] = val;
134+ }
97135 }
98- inline void SetDim (int i, int32_t val)
136+
137+ // Returns a pointer to the dimension data (mutable).
138+ inline int32_t *DimsData ()
99139 {
100- assert (i >= 0 );
101- assert (i < _size);
102- if (_size > kMaxSmallSize )
140+ if (_size <= kMaxSmallSize )
103141 {
104- _dims_pointer[i] = val ;
142+ return std::get<std::array< int32_t , kMaxSmallSize >>(dims_). data () ;
105143 }
106144 else
107145 {
108- _dims[i] = val ;
146+ return std::get<std::vector< int32_t >>(dims_). data () ;
109147 }
110148 }
111149
112- inline int32_t *DimsData () { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
113- inline const int32_t *DimsData () const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
114- // The caller must ensure that the shape is no bigger than 6D.
115- inline const int32_t *DimsDataUpTo6D () const { return _dims; }
150+ // Returns a pointer to the dimension data (const).
151+ inline const int32_t *DimsData () const
152+ {
153+ if (_size <= kMaxSmallSize )
154+ {
155+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_).data ();
156+ }
157+ else
158+ {
159+ return std::get<std::vector<int32_t >>(dims_).data ();
160+ }
161+ }
162+
163+ // The caller must ensure that the shape is no larger than 6D.
164+ inline const int32_t *DimsDataUpTo6D () const
165+ {
166+ return std::get<std::array<int32_t , kMaxSmallSize >>(dims_).data ();
167+ }
116168
169+ // Resizes the shape to dimensions_count.
117170 inline void Resize (int dimensions_count)
118171 {
119- if (_size > kMaxSmallSize )
172+ _size = dimensions_count;
173+ if (dimensions_count <= kMaxSmallSize )
120174 {
121- delete[] _dims_pointer ;
175+ dims_ = std::array< int32_t , kMaxSmallSize >{} ;
122176 }
123- _size = dimensions_count;
124- if (dimensions_count > kMaxSmallSize )
177+ else
125178 {
126- _dims_pointer = new int32_t [ dimensions_count] ;
179+ dims_ = std::vector< int32_t >( dimensions_count) ;
127180 }
128181 }
129182
183+ // Replaces the current shape with a new one defined by dimensions_count and dims_data.
130184 inline void ReplaceWith (int dimensions_count, const int32_t *dims_data)
131185 {
132186 Resize (dimensions_count);
133- int32_t *dst_dims = DimsData ();
134- std::memcpy (dst_dims, dims_data, dimensions_count * sizeof (int32_t ));
187+ std::memcpy (DimsData (), dims_data, dimensions_count * sizeof (int32_t ));
135188 }
136189
190+ // Replaces the current shape with another shape.
137191 inline void ReplaceWith (const Shape &other)
138192 {
139193 ReplaceWith (other.DimensionsCount (), other.DimsData ());
140194 }
141195
196+ // Replaces the current shape with another shape using move semantics.
142197 inline void ReplaceWith (Shape &&other)
143198 {
144- Resize (0 );
145199 std::swap (_size, other._size );
146- if (_size <= kMaxSmallSize )
147- std::copy (other._dims , other._dims + kMaxSmallSize , _dims);
148- else
149- _dims_pointer = other._dims_pointer ;
200+ dims_ = std::move (other.dims_ );
150201 }
151202
152- template <typename T> inline void BuildFrom (const T &src_iterable)
203+ // Builds the shape from an iterable sequence.
204+ template <typename Iterable> inline void BuildFrom (const Iterable &src_iterable)
153205 {
154- const int dimensions_count = std::distance (src_iterable.begin (), src_iterable.end ());
206+ const int dimensions_count =
207+ static_cast <int >(std::distance (src_iterable.begin (), src_iterable.end ()));
155208 Resize (dimensions_count);
156209 int32_t *data = DimsData ();
157- for (auto && it : src_iterable)
210+ for (auto it = src_iterable. begin (); it != src_iterable. end (); ++it )
158211 {
159- *data = it;
160- ++data;
212+ *data++ = static_cast <int32_t >(*it);
161213 }
162214 }
163215
164- // This will probably be factored out. Old code made substantial use of 4-D
165- // shapes, and so this function is used to extend smaller shapes. Note that
166- // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
167- // reduced, and (b) some kernels are stricly 4-D, but then the shapes of their
168- // inputs should already be 4-D, so this function should not be needed.
216+ // Returns the total count of elements, that is the size when flattened into a
217+ // vector.
169218 inline static Shape ExtendedShape (int new_shape_size, const Shape &shape)
170219 {
171220 return Shape (new_shape_size, shape, 1 );
172221 }
173222
223+ // Overload for initializer list building.
174224 inline void BuildFrom (const std::initializer_list<int > init_list)
175225 {
176226 BuildFrom<const std::initializer_list<int >>(init_list);
177227 }
178228
179- // Returns the total count of elements, that is the size when flattened into a
180- // vector.
229+ // Returns the total count of elements (flattened size).
181230 inline int FlatSize () const
182231 {
183232 int buffer_size = 1 ;
184233 const int *dims_data = DimsData ();
185234 for (int i = 0 ; i < _size; i++)
186235 {
187- const int dim = dims_data[i];
188- buffer_size *= dim;
236+ buffer_size *= dims_data[i];
189237 }
190238 return buffer_size;
191239 }
@@ -211,12 +259,13 @@ class Shape
211259 }
212260
213261 int32_t _size;
214- union {
215- int32_t _dims[kMaxSmallSize ];
216- int32_t *_dims_pointer{nullptr };
217- };
262+ // Internal storage: use std::array for shapes with dimensions up to kMaxSmallSize,
263+ // and std::vector for larger shapes.
264+ std::variant<std::array<int32_t , kMaxSmallSize >, std::vector<int32_t >> dims_;
218265};
219266
267+ // Utility functions below.
268+
220269inline int MatchingDim (const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
221270 [[maybe_unused]] int index2)
222271{
@@ -232,7 +281,10 @@ int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &s
232281 return MatchingDim (shape1, index1, args...);
233282}
234283
235- inline Shape GetShape (const std::vector<int32_t > &data) { return Shape (data.size (), data.data ()); }
284+ inline Shape GetShape (const std::vector<int32_t > &data)
285+ {
286+ return Shape (static_cast <int >(data.size ()), data.data ());
287+ }
236288
237289inline int Offset (const Shape &shape, int i0, int i1, int i2, int i3)
238290{
@@ -278,8 +330,7 @@ inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
278330 return flat_size;
279331}
280332
281- // Flat size calculation, checking that dimensions match with one or more other
282- // arrays.
333+ // Flat size calculation, checking that dimensions match with one or more other shapes.
283334template <typename ... Ts> inline bool checkMatching (const Shape &shape, Ts... check_shapes)
284335{
285336 auto match = [&shape](const Shape &s) -> bool {
0 commit comments