Skip to content

Commit 167d50b

Browse files
authored
Dyno resolve rectangular array literals (#26832)
Add support for resolving rectangular array literals (including multidimensional) using module code calls in Dyno. Also includes: - Implementing a new `FlatteningArrayIterator` which yields elements of a multi-dimensional (or 1D) array as a flat list. This is used for Dyno array resolution and also replaces similar recursive logic in the converter. - Reverting module code workaround #26795 which appears to be no longer needed. - Dyno tests for rectangular array literals. Resolves Cray/chapel-private#7191. [reviewed by @DanilaFe, thanks!] Testing: - [x] dyno tests - [x] paratest - [x] expected spectests pass, including newly passing: - [x] `release/examples/spec/Arrays/adecl-2x2x3-literal` - [x] `release/examples/spec/Arrays/adecl-literal`
2 parents f8420c4 + d15d955 commit 167d50b

File tree

8 files changed

+271
-29
lines changed

8 files changed

+271
-29
lines changed

compiler/passes/convert-uast.cpp

+4-15
Original file line numberDiff line numberDiff line change
@@ -1711,23 +1711,12 @@ struct Converter final : UastConverter {
17111711

17121712
/// Array, Domain, Range, Tuple ///
17131713

1714-
void convertArrayRow(const uast::ArrayRow* node, CallExpr* actualList) {
1715-
for (auto expr : node->exprs()) {
1716-
if (expr->isArrayRow()) {
1717-
convertArrayRow(expr->toArrayRow(), actualList);
1718-
} else {
1719-
actualList->insertAtTail(convertAST(expr));
1720-
}
1721-
}
1722-
}
1723-
17241714
Expr* visit(const uast::Array* node) {
17251715
CallExpr* actualList = new CallExpr(PRIM_ACTUALS_LIST);
17261716
Expr* shapeList = nullptr;
17271717
bool isAssociativeList = false;
17281718

1729-
bool isNDArray = node->numExprs() >= 1 && node->expr(0)->isArrayRow();
1730-
if (!isNDArray) {
1719+
if (!node->isMultiDim()) {
17311720
for (auto expr : node->exprs()) {
17321721
bool hasConvertedThisIter = false;
17331722

@@ -1763,13 +1752,13 @@ struct Converter final : UastConverter {
17631752
}
17641753
shapeList = new CallExpr("_build_tuple", shapeActualList);
17651754

1766-
for (auto expr : node->exprs()) {
1767-
convertArrayRow(expr->toArrayRow(), actualList);
1755+
for (auto expr : node->flattenedExprs()) {
1756+
actualList->insertAtTail(convertAST(expr));
17681757
}
17691758
}
17701759

17711760
Expr* ret = nullptr;
1772-
if (!isNDArray) {
1761+
if (!node->isMultiDim()) {
17731762
INT_ASSERT(shapeList == nullptr);
17741763
if (isAssociativeList) {
17751764
ret = new CallExpr("chpl__buildAssociativeArrayExpr", actualList);

doc/util/nitpick_ignore

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ cpp:identifier Location
3636
cpp:identifier AstTag
3737
cpp:identifier uast::AstTag
3838
cpp:identifier AstList::const_iterator
39+
cpp:identifier AstListIt::value_type
40+
cpp:identifier AstListIt::difference_type
41+
cpp:identifier AstListIt::pointer
42+
cpp:identifier AstListIt::reference
3943
cpp:identifier size_t
4044
cpp:identifier detail
4145
cpp:identifier detail::PODUniqueString

frontend/include/chpl/uast/Array.h

+154-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
#include "chpl/framework/Location.h"
2424
#include "chpl/uast/AstNode.h"
25+
#include "chpl/uast/ArrayRow.h"
26+
#include <iterator>
2527

2628
namespace chpl {
2729
namespace uast {
@@ -45,29 +47,34 @@ class Array final : public AstNode {
4547

4648
private:
4749
bool trailingComma_,
48-
associative_;
50+
associative_,
51+
isMultiDim_;
4952

5053
Array(AstList children, bool trailingComma, bool associative)
51-
: AstNode(asttags::Array, std::move(children)),
52-
trailingComma_(trailingComma),
53-
associative_(associative) {
54+
: AstNode(asttags::Array, std::move(children)),
55+
trailingComma_(trailingComma),
56+
associative_(associative) {
57+
isMultiDim_ = this->numExprs() > 0 && this->expr(0)->isArrayRow();
5458
}
5559

5660
void serializeInner(Serializer& ser) const override {
5761
ser.write(trailingComma_);
5862
ser.write(associative_);
63+
ser.write(isMultiDim_);
5964
}
6065

6166
explicit Array(Deserializer& des)
6267
: AstNode(asttags::Array, des) {
6368
trailingComma_ = des.read<bool>();
6469
associative_ = des.read<bool>();
70+
isMultiDim_ = des.read<bool>();
6571
}
6672

6773
bool contentsMatchInner(const AstNode* other) const override {
6874
const Array* rhs = other->toArray();
6975
return this->trailingComma_ == rhs->trailingComma_ &&
70-
this->associative_ == rhs->associative_;
76+
this->associative_ == rhs->associative_ &&
77+
this->isMultiDim_ == rhs->isMultiDim_;
7178
}
7279

7380
void markUniqueStringsInner(Context* context) const override {
@@ -92,8 +99,7 @@ class Array final : public AstNode {
9299
Return a way to iterate over the expressions of this array.
93100
*/
94101
AstListIteratorPair<AstNode> exprs() const {
95-
return AstListIteratorPair<AstNode>(children_.begin(),
96-
children_.end());
102+
return AstListIteratorPair<AstNode>(children_.begin(), children_.end());
97103
}
98104

99105
/**
@@ -110,6 +116,147 @@ class Array final : public AstNode {
110116
const AstNode* ast = this->child(i);
111117
return ast;
112118
}
119+
120+
/**
121+
Return whether this is a multi-dimensional array.
122+
*/
123+
bool isMultiDim() const {
124+
return this->isMultiDim_;
125+
}
126+
127+
/**
128+
* Return the shape of this multi-dim array, as a list of dimension lengths.
129+
*/
130+
std::vector<int> shape() const {
131+
CHPL_ASSERT(this->isMultiDim());
132+
std::vector<int> ret;
133+
ret.emplace_back(this->numExprs());
134+
auto cur = this->expr(0);
135+
while(cur->toArrayRow()) {
136+
ret.emplace_back(cur->toArrayRow()->numExprs());
137+
cur = cur->toArrayRow()->expr(0);
138+
}
139+
return ret;
140+
}
141+
142+
/**
143+
* An iterator that flattens a multi-dimensional array into a single list.
144+
*/
145+
class FlatteningArrayIterator {
146+
public:
147+
using AstListIt = AstListIterator<AstNode>;
148+
using iterator_category = std::forward_iterator_tag;
149+
using value_type = AstListIt::value_type;
150+
using difference_type = AstListIt::difference_type;
151+
using pointer = AstListIt::pointer;
152+
using reference = AstListIt::reference;
153+
154+
private:
155+
// Stack of current row iterator positions, one for each dimension. The
156+
// bottom iterates over the array itself, and the top iterates over a row of
157+
// innermost dimension.
158+
// Each entry is a pair of (current, end) iterators.
159+
llvm::SmallVector<std::pair<AstListIt, AstListIt>, 1> rowIterStack;
160+
161+
/*
162+
* Descend to the innermost array dimension, adding an iterator for each
163+
* dimension along the way.
164+
*/
165+
void descendDims() {
166+
CHPL_ASSERT(!rowIterStack.empty() && "should not be possible");
167+
while (auto row = (*rowIterStack.back().first)->toArrayRow()) {
168+
CHPL_ASSERT(row->numExprs() > 0 && "empty rows not supported");
169+
const auto exprs = row->exprs();
170+
this->rowIterStack.emplace_back(exprs.begin(), exprs.end());
171+
}
172+
}
173+
174+
FlatteningArrayIterator(AstListIt begin, AstListIt end) {
175+
rowIterStack.emplace_back(begin, end);
176+
}
177+
178+
static void assertNonEmptyArr(const Array* arr) {
179+
CHPL_ASSERT(arr->numExprs() > 0 && "empty arrays not supported");
180+
}
181+
182+
public:
183+
// Construct an iterator starting at the beginning of the array
184+
static FlatteningArrayIterator normal(const Array* iterand) {
185+
assertNonEmptyArr(iterand);
186+
FlatteningArrayIterator ret(iterand->exprs().begin(),
187+
iterand->exprs().end());
188+
ret.descendDims();
189+
return ret;
190+
}
191+
192+
// Construct an iterator starting at the end of the array
193+
static FlatteningArrayIterator end(const Array* iterand) {
194+
assertNonEmptyArr(iterand);
195+
return FlatteningArrayIterator(iterand->exprs().end(),
196+
iterand->exprs().end());
197+
}
198+
199+
bool operator==(const FlatteningArrayIterator rhs) const {
200+
// Should only be necessary to compare the innermost-dimension iterator
201+
// pairs.
202+
// If we add support for empty arrays/rows we'll have to compare (up to)
203+
// the entire stack, as multiple empty rows could have the same begin
204+
// and end iterators.
205+
return this->rowIterStack.back() == rhs.rowIterStack.back();
206+
}
207+
bool operator!=(const FlatteningArrayIterator rhs) const {
208+
return !(*this == rhs);
209+
}
210+
211+
const AstNode* operator*() const {
212+
return *this->rowIterStack.back().first;
213+
}
214+
const AstNode* operator->() const { return operator*(); }
215+
216+
FlatteningArrayIterator& operator++() {
217+
// Pop up the stack until we're either at the top level, or at a row we
218+
// haven't already gone through.
219+
while (++rowIterStack.back().first == rowIterStack.back().second) {
220+
// Special case: leave the top level array iterator on the stack
221+
// when it hits the end.
222+
if (rowIterStack.size() == 1) return *this;
223+
rowIterStack.pop_back();
224+
}
225+
226+
// We're in an unfinished row; continue iteration from the innermost
227+
// dimension under this row.
228+
descendDims();
229+
return *this;
230+
}
231+
232+
FlatteningArrayIterator operator++(int) {
233+
FlatteningArrayIterator tmp = *this;
234+
operator++();
235+
return tmp;
236+
}
237+
};
238+
239+
struct FlatteningArrayIteratorPair {
240+
FlatteningArrayIterator begin_;
241+
FlatteningArrayIterator end_;
242+
243+
FlatteningArrayIteratorPair(FlatteningArrayIterator begin,
244+
FlatteningArrayIterator end)
245+
: begin_(begin), end_(end) {}
246+
~FlatteningArrayIteratorPair() = default;
247+
248+
FlatteningArrayIterator begin() const { return begin_; }
249+
FlatteningArrayIterator end() const { return end_; }
250+
};
251+
252+
/**
253+
Return a way to iterate over the expressions of this array, transparently
254+
flattened into a single list if multi-dimensional.
255+
*/
256+
FlatteningArrayIteratorPair flattenedExprs() const {
257+
return FlatteningArrayIteratorPair(FlatteningArrayIterator::normal(this),
258+
FlatteningArrayIterator::end(this));
259+
}
113260
};
114261

115262

frontend/include/chpl/uast/ArrayRow.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "chpl/framework/Location.h"
2424
#include "chpl/uast/AstNode.h"
25+
#include <iterator>
2526

2627
namespace chpl {
2728
namespace uast {
@@ -77,8 +78,7 @@ class ArrayRow final : public AstNode {
7778
Return a way to iterate over the expressions of this array row.
7879
*/
7980
AstListIteratorPair<AstNode> exprs() const {
80-
return AstListIteratorPair<AstNode>(children_.begin(),
81-
children_.end());
81+
return AstListIteratorPair<AstNode>(children_.begin(), children_.end());
8282
}
8383

8484
/**

frontend/lib/resolution/Resolver.cpp

+54-1
Original file line numberDiff line numberDiff line change
@@ -3035,7 +3035,7 @@ shouldSkipCallResolution(Resolver* rv, const uast::AstNode* callLike,
30353035
qt.isRef() == false) {
30363036
// don't skip because it could be initialized with 'out' intent,
30373037
// but not for non-out formals because they can't be split-initialized.
3038-
} else if (actualAst->isTypeQuery() && ci.calledType().isType()) {
3038+
} else if (actualAst && actualAst->isTypeQuery() && ci.calledType().isType()) {
30393039
// don't skip for type queries in type constructors
30403040
} else {
30413041
if (qt.isParam() && qt.param() == nullptr) {
@@ -4883,6 +4883,59 @@ void Resolver::exit(const Range* range) {
48834883
}
48844884
}
48854885

4886+
bool Resolver::enter(const uast::Array* decl) {
4887+
return true;
4888+
}
4889+
void Resolver::exit(const uast::Array* decl) {
4890+
if (scopeResolveOnly) {
4891+
return;
4892+
}
4893+
4894+
ResolvedExpression& r = byPostorder.byAst(decl);
4895+
4896+
// Resolve call to appropriate array builder proc
4897+
const char* arrayBuilderProc;
4898+
std::vector<CallInfoActual> actuals;
4899+
std::vector<const uast::AstNode*> actualAsts;
4900+
if (!decl->isMultiDim()) {
4901+
arrayBuilderProc = "chpl__buildArrayExpr";
4902+
} else {
4903+
arrayBuilderProc = "chpl__buildNDArrayExpr";
4904+
4905+
// Get shape arg
4906+
std::vector<QualifiedType> shapeTupleElts;
4907+
for (auto dim : decl->shape()) {
4908+
shapeTupleElts.push_back(QualifiedType::makeParamInt(context, dim));
4909+
}
4910+
auto shapeTupleType = TupleType::getQualifiedTuple(context, shapeTupleElts);
4911+
actualAsts.push_back(nullptr);
4912+
actuals.emplace_back(
4913+
QualifiedType(QualifiedType::CONST_VAR, shapeTupleType),
4914+
UniqueString());
4915+
}
4916+
4917+
// Add element args
4918+
for (auto expr : decl->flattenedExprs()) {
4919+
actualAsts.push_back(expr);
4920+
actuals.emplace_back(byPostorder.byAst(expr).type(), UniqueString());
4921+
}
4922+
4923+
auto ci = CallInfo(/* name */ UniqueString::get(context, arrayBuilderProc),
4924+
/* calledType */ QualifiedType(),
4925+
/* isMethodCall */ false,
4926+
/* hasQuestionArg */ false,
4927+
/* isParenless */ false, actuals);
4928+
if (shouldSkipCallResolution(this, decl, actualAsts, ci)) {
4929+
r.setType(QualifiedType());
4930+
return;
4931+
}
4932+
auto scope = currentScope();
4933+
auto inScopes = CallScopeInfo::forNormalCall(scope, poiScope);
4934+
auto c = resolveGeneratedCall(decl, &ci, &inScopes);
4935+
4936+
c.noteResult(&r);
4937+
}
4938+
48864939
bool Resolver::enter(const uast::Domain* decl) {
48874940
return true;
48884941
}

frontend/lib/resolution/Resolver.h

+3
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,9 @@ struct Resolver : BranchSensitiveVisitor<DefaultFrame> {
753753
bool enter(const uast::Range* decl);
754754
void exit(const uast::Range* decl);
755755

756+
bool enter(const uast::Array* decl);
757+
void exit(const uast::Array* decl);
758+
756759
bool enter(const uast::Domain* decl);
757760
void exit(const uast::Domain* decl);
758761

0 commit comments

Comments
 (0)