22
22
23
23
#include " chpl/framework/Location.h"
24
24
#include " chpl/uast/AstNode.h"
25
+ #include " chpl/uast/ArrayRow.h"
26
+ #include < iterator>
25
27
26
28
namespace chpl {
27
29
namespace uast {
@@ -45,29 +47,34 @@ class Array final : public AstNode {
45
47
46
48
private:
47
49
bool trailingComma_,
48
- associative_;
50
+ associative_,
51
+ isMultiDim_;
49
52
50
53
Array (AstList children, bool trailingComma, bool associative)
51
- : AstNode(asttags::Array, std::move(children)),
52
- trailingComma_ (trailingComma),
53
- associative_(associative) {
54
+ : AstNode(asttags::Array, std::move(children)),
55
+ trailingComma_ (trailingComma),
56
+ associative_(associative) {
57
+ isMultiDim_ = this ->numExprs () > 0 && this ->expr (0 )->isArrayRow ();
54
58
}
55
59
56
60
void serializeInner (Serializer& ser) const override {
57
61
ser.write (trailingComma_);
58
62
ser.write (associative_);
63
+ ser.write (isMultiDim_);
59
64
}
60
65
61
66
explicit Array (Deserializer& des)
62
67
: AstNode(asttags::Array, des) {
63
68
trailingComma_ = des.read <bool >();
64
69
associative_ = des.read <bool >();
70
+ isMultiDim_ = des.read <bool >();
65
71
}
66
72
67
73
bool contentsMatchInner (const AstNode* other) const override {
68
74
const Array* rhs = other->toArray ();
69
75
return this ->trailingComma_ == rhs->trailingComma_ &&
70
- this ->associative_ == rhs->associative_ ;
76
+ this ->associative_ == rhs->associative_ &&
77
+ this ->isMultiDim_ == rhs->isMultiDim_ ;
71
78
}
72
79
73
80
void markUniqueStringsInner (Context* context) const override {
@@ -92,8 +99,7 @@ class Array final : public AstNode {
92
99
Return a way to iterate over the expressions of this array.
93
100
*/
94
101
AstListIteratorPair<AstNode> exprs () const {
95
- return AstListIteratorPair<AstNode>(children_.begin (),
96
- children_.end ());
102
+ return AstListIteratorPair<AstNode>(children_.begin (), children_.end ());
97
103
}
98
104
99
105
/* *
@@ -110,6 +116,147 @@ class Array final : public AstNode {
110
116
const AstNode* ast = this ->child (i);
111
117
return ast;
112
118
}
119
+
120
+ /* *
121
+ Return whether this is a multi-dimensional array.
122
+ */
123
+ bool isMultiDim () const {
124
+ return this ->isMultiDim_ ;
125
+ }
126
+
127
+ /* *
128
+ * Return the shape of this multi-dim array, as a list of dimension lengths.
129
+ */
130
+ std::vector<int > shape () const {
131
+ CHPL_ASSERT (this ->isMultiDim ());
132
+ std::vector<int > ret;
133
+ ret.emplace_back (this ->numExprs ());
134
+ auto cur = this ->expr (0 );
135
+ while (cur->toArrayRow ()) {
136
+ ret.emplace_back (cur->toArrayRow ()->numExprs ());
137
+ cur = cur->toArrayRow ()->expr (0 );
138
+ }
139
+ return ret;
140
+ }
141
+
142
+ /* *
143
+ * An iterator that flattens a multi-dimensional array into a single list.
144
+ */
145
+ class FlatteningArrayIterator {
146
+ public:
147
+ using AstListIt = AstListIterator<AstNode>;
148
+ using iterator_category = std::forward_iterator_tag;
149
+ using value_type = AstListIt::value_type;
150
+ using difference_type = AstListIt::difference_type;
151
+ using pointer = AstListIt::pointer;
152
+ using reference = AstListIt::reference;
153
+
154
+ private:
155
+ // Stack of current row iterator positions, one for each dimension. The
156
+ // bottom iterates over the array itself, and the top iterates over a row of
157
+ // innermost dimension.
158
+ // Each entry is a pair of (current, end) iterators.
159
+ llvm::SmallVector<std::pair<AstListIt, AstListIt>, 1 > rowIterStack;
160
+
161
+ /*
162
+ * Descend to the innermost array dimension, adding an iterator for each
163
+ * dimension along the way.
164
+ */
165
+ void descendDims () {
166
+ CHPL_ASSERT (!rowIterStack.empty () && " should not be possible" );
167
+ while (auto row = (*rowIterStack.back ().first )->toArrayRow ()) {
168
+ CHPL_ASSERT (row->numExprs () > 0 && " empty rows not supported" );
169
+ const auto exprs = row->exprs ();
170
+ this ->rowIterStack .emplace_back (exprs.begin (), exprs.end ());
171
+ }
172
+ }
173
+
174
+ FlatteningArrayIterator (AstListIt begin, AstListIt end) {
175
+ rowIterStack.emplace_back (begin, end);
176
+ }
177
+
178
+ static void assertNonEmptyArr (const Array* arr) {
179
+ CHPL_ASSERT (arr->numExprs () > 0 && " empty arrays not supported" );
180
+ }
181
+
182
+ public:
183
+ // Construct an iterator starting at the beginning of the array
184
+ static FlatteningArrayIterator normal (const Array* iterand) {
185
+ assertNonEmptyArr (iterand);
186
+ FlatteningArrayIterator ret (iterand->exprs ().begin (),
187
+ iterand->exprs ().end ());
188
+ ret.descendDims ();
189
+ return ret;
190
+ }
191
+
192
+ // Construct an iterator starting at the end of the array
193
+ static FlatteningArrayIterator end (const Array* iterand) {
194
+ assertNonEmptyArr (iterand);
195
+ return FlatteningArrayIterator (iterand->exprs ().end (),
196
+ iterand->exprs ().end ());
197
+ }
198
+
199
+ bool operator ==(const FlatteningArrayIterator rhs) const {
200
+ // Should only be necessary to compare the innermost-dimension iterator
201
+ // pairs.
202
+ // If we add support for empty arrays/rows we'll have to compare (up to)
203
+ // the entire stack, as multiple empty rows could have the same begin
204
+ // and end iterators.
205
+ return this ->rowIterStack .back () == rhs.rowIterStack .back ();
206
+ }
207
+ bool operator !=(const FlatteningArrayIterator rhs) const {
208
+ return !(*this == rhs);
209
+ }
210
+
211
+ const AstNode* operator *() const {
212
+ return *this ->rowIterStack .back ().first ;
213
+ }
214
+ const AstNode* operator ->() const { return operator *(); }
215
+
216
+ FlatteningArrayIterator& operator ++() {
217
+ // Pop up the stack until we're either at the top level, or at a row we
218
+ // haven't already gone through.
219
+ while (++rowIterStack.back ().first == rowIterStack.back ().second ) {
220
+ // Special case: leave the top level array iterator on the stack
221
+ // when it hits the end.
222
+ if (rowIterStack.size () == 1 ) return *this ;
223
+ rowIterStack.pop_back ();
224
+ }
225
+
226
+ // We're in an unfinished row; continue iteration from the innermost
227
+ // dimension under this row.
228
+ descendDims ();
229
+ return *this ;
230
+ }
231
+
232
+ FlatteningArrayIterator operator ++(int ) {
233
+ FlatteningArrayIterator tmp = *this ;
234
+ operator ++();
235
+ return tmp;
236
+ }
237
+ };
238
+
239
+ struct FlatteningArrayIteratorPair {
240
+ FlatteningArrayIterator begin_;
241
+ FlatteningArrayIterator end_;
242
+
243
+ FlatteningArrayIteratorPair (FlatteningArrayIterator begin,
244
+ FlatteningArrayIterator end)
245
+ : begin_(begin), end_(end) {}
246
+ ~FlatteningArrayIteratorPair () = default ;
247
+
248
+ FlatteningArrayIterator begin () const { return begin_; }
249
+ FlatteningArrayIterator end () const { return end_; }
250
+ };
251
+
252
+ /* *
253
+ Return a way to iterate over the expressions of this array, transparently
254
+ flattened into a single list if multi-dimensional.
255
+ */
256
+ FlatteningArrayIteratorPair flattenedExprs () const {
257
+ return FlatteningArrayIteratorPair (FlatteningArrayIterator::normal (this ),
258
+ FlatteningArrayIterator::end (this ));
259
+ }
113
260
};
114
261
115
262
0 commit comments