Skip to content

Dyno: Type-convert select-statements #27095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 179 additions & 9 deletions compiler/passes/convert-typed-uast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ struct TConverter final : UastConverter {

// Create a new temporary. The type used for it must be supplied.
Symbol* makeNewTemp(const types::QualifiedType& qt, bool insertDef=true);
Symbol* makeNewTemp(::Qualifier qual, Type* t, bool insertDef=true);

// Store 'e' in a temporary if it does not already refer to one. The type
// for the temporary must be provided and cannot easily be retrieved from
Expand Down Expand Up @@ -1004,6 +1005,9 @@ struct TConverter final : UastConverter {
bool enter(const ExternBlock* node, RV& rv);
void exit(const ExternBlock* node, RV& rv);

bool enter(const Select* node, RV& rv);
void exit(const Select* node, RV& rv);

bool enter(const AstNode* node, RV& rv);
void exit(const AstNode* node, RV& rv);
};
Expand Down Expand Up @@ -1167,7 +1171,9 @@ void TConverter::convertFunctionsToConvert() {
}

for (auto pair : v) {
convertFunction(pair.first);
if (fns.find(pair.first) == fns.end()) {
convertFunction(pair.first);
}
}

// Create 'chpl_gen_main()' as well as an empty 'main()' function if needed.
Expand Down Expand Up @@ -2803,28 +2809,40 @@ TConverter::defaultValueForType(const types::Type* t,
}

Symbol*
TConverter::makeNewTemp(const types::QualifiedType& qt, bool insertDef) {
TConverter::makeNewTemp(::Qualifier qual, Type* t, bool insertDef) {
auto ret = newTemp();
ret->addFlag(FLAG_EXPR_TEMP);
ret->qual = convertQualifier(qt.kind());
ret->type = convertType(qt.type());
ret->qual = qual;
ret->type = t;

if (insertDef) insertStmt(new DefExpr(ret));

return ret;
}

Symbol*
TConverter::makeNewTemp(const types::QualifiedType& qt, bool insertDef) {
return makeNewTemp(convertQualifier(qt.kind()), convertType(qt.type()), insertDef);
}

SymExpr*
TConverter::storeInTempIfNeeded(Expr* e, const types::QualifiedType& qt) {
if (SymExpr* se = toSymExpr(e)) {
return se;
}

// Prevent default-constructed 'QualifiedType' from slipping in.
INT_ASSERT(qt.hasTypePtr());
Symbol* t = nullptr;

if (qt.isUnknown()) {
t = makeNewTemp(e->qualType().getQual(), e->typeInfo());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with this is that there are at least some paths through ->qualType() which seem to invoke the old resolver which is why I decided to only allow temps created with the chpl::types::QualifiedType. If we could figure out some way to translate without invoking the old resolver this could be done.

E.g., I know there are some primitives where the types are determined via the old resolver. I have not done due diligence but there could be more.

Maybe we hold off on adding this sort of thing if we can and if we do make sure via sanity checks/crashes that it can't actually invoke the old resolver.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was unaware that this was the case, so I'll look into using QualifiedType.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, thanks, maybe consider adding a comment or something so that people in the future can catch on? I thought I added one but it was probably in a bad spot or too long winded.

} else {
// Prevent default-constructed 'QualifiedType' from slipping in.
INT_ASSERT(qt.hasTypePtr());

// otherwise, store the value in a temp
t = makeNewTemp(qt);
}

// otherwise, store the value in a temp
auto t = makeNewTemp(qt);
insertStmt(new CallExpr(PRIM_MOVE, t, e));
return new SymExpr(t);
}
Expand Down Expand Up @@ -4050,7 +4068,8 @@ void TConverter::ActualConverter::convertActual(const FormalActual& fa) {
INT_ASSERT(astActual);

// Convert the actual and leave.
std::get<Expr*>(slot) = tc_->convertExpr(astActual, rv_);
auto actualExpr = tc_->convertExpr(astActual, rv_);
std::get<Expr*>(slot) = tc_->storeInTempIfNeeded(actualExpr, {});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we always need to have the actuals be temps?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The uAST here could be a Call*, and normalized production AST requires a temporary.

Copy link
Contributor

@dlongnecke-cray dlongnecke-cray Apr 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me, thanks! Maybe we could have storeInTempIfNeeded be more sensitive so it can avoid that if possible, because currently it just always stores in a temp.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the TODO, I think it would be good to have the callsites supply the types unless we're sure we're not going to accidentally invoke the old resolver.

So just get the type from convertExpr for now?


return;
}
Expand Down Expand Up @@ -5071,6 +5090,157 @@ bool TConverter::enter(const ExternBlock* node, RV& rv) {

void TConverter::exit(const ExternBlock* node, RV& rv) {}

static SymExpr* makeCaseCond(TConverter& tc, TConverter::RV& rv,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should normalize either TConverter* or TConverter& for static helper functions so that they're consistent, I don't really have a preference here but usually just do TConverter* cause then you can pass in this. Thoughts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I prefer to use references when possible, but for consistency I'll change this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with using references, I'd just want to go and change the other areas to use them too 😄 is all. So feel free to keep and I'll add a TODO.

const uast::When* when,
SymExpr* selectExpr,
const uast::AstNode* cs) {
auto re = rv.byAst(cs);

// Grab the '==' ResolvedFunction
const AssociatedAction* action = nullptr;
for (const auto& a : re.associatedActions()) {
if (a.action() == AssociatedAction::COMPARE) {
action = &a;
break;
} else {
INT_ASSERT(false);
}
}
auto cmp = action->fn();

// TODO: create a wrapper for this kind of thing
const ResolvedFunction* rf = nullptr;
if (tc.paramElideCallOrNull(cmp, re.poiScope(), &rf)) {
INT_ASSERT(false && "Should not have param elided initializer call!");
} else if (!rf) {
return toSymExpr(TC_PLACEHOLDER((&tc)));
}

auto fn = tc.findOrConvertFunction(rf);
INT_ASSERT(fn);

// TODO: is there a better way to represent the return type of '=='? As-is,
// it's a bit confusing for the case-expression's type in 'rv' to always be
// a bool.
Expr* caseExpr = tc.convertExpr(cs, rv);
caseExpr = tc.storeInTempIfNeeded(caseExpr, {});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get the type from convertExpr and pass that to storeInTempIfNeeded here?


// TODO: handle case where passing to '==' has an associated action
auto call = new CallExpr(fn, selectExpr->copy(), caseExpr);

auto cond = tc.storeInTempIfNeeded(call, re.type());

return cond;
}

static SymExpr* getWhenCond(TConverter& tc, TConverter::RV& rv,
const uast::When* when,
SymExpr* selectExpr) {
if (when->numCaseExprs() == 1) {
auto cs = when->caseExpr(0);
return makeCaseCond(tc, rv, when, selectExpr, cs);
} else {
// Multiple cases should follow '||'-like short-circuiting, such that
// equality comparisons are not evaluated unless the preceding cases do
// not match.

VarSymbol* agg = new VarSymbol("_case_cond_agg", dtBool);
Copy link
Contributor

@dlongnecke-cray dlongnecke-cray Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a unnamed temp here instead? I ask because just the other day I ran into a bug where I decided to name a temp chpl_error, and that actually caused the backend to be unable to locate the runtime function chpl_error. I feel like just using unnamed temps wherever possible might eliminate any chance of name conflicts.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a case where I'd prefer to use the name to make logged AST easier to read. A quick search suggests this particular name isn't in use anywhere else.

tc.insertStmt(new CallExpr(PRIM_MOVE, agg, gFalse));
tc.insertStmt(new DefExpr(agg));

int count = 0;
for (auto cs : when->caseExprs()) {
// If it's param-true, 'getWhenCond' would not have been called
// If it's param-false, there's no need to generate any comparisons
if (rv.byAst(cs).type().isParam()) continue;

auto cond = makeCaseCond(tc, rv, when, selectExpr, cs);

auto thenBlock = new BlockStmt();
thenBlock->insertAtTail(new CallExpr(PRIM_MOVE, agg, gTrue));

auto elseBlock = new BlockStmt();
auto branch = new CondStmt(cond, thenBlock, elseBlock);
tc.insertStmt(branch);

// Push the else branch so that the next case check is inserted there
tc.pushBlock(elseBlock);
count += 1;
}

// pop the blocks we pushed, to get back to the when-stmt level
for (int i = 0; i < count; i++) {
tc.popBlock();
}

return new SymExpr(agg);
}
}

bool TConverter::enter(const Select* node, RV& rv) {
// TODO:
// - test case-exprs where the '==' operators have associated actions
// (e.g. in-intent on record argument)

//Note: out-of-order otherwise is an error addressed by post-parse-checks

types::QualifiedType selectQT;
auto selectExpr = convertExpr(node->expr(), rv, &selectQT);
auto selectSym = storeInTempIfNeeded(selectExpr, selectQT);

int count = 0;
for (auto when : node->whenStmts()) {
if (when->isOtherwise()) {
when->body()->traverse(rv);
} else {
bool anyParamTrue = false;
bool allParamFalse = true;
for(auto cs : when->caseExprs()) {
auto qt = rv.byAst(cs).type();
if (!qt.isParamFalse()) {
allParamFalse = false;
}

if (qt.isParamTrue()) {
anyParamTrue = true;
}
}

// Note: this differs from traditional behavior in production, where each
// comparison was made.
if (anyParamTrue) {
when->body()->traverse(rv);

// Nothing else can match after this, so we're done
break;
} else if (!allParamFalse) {
auto cond = getWhenCond(*this, rv, when, selectSym);

auto thenBlock = new BlockStmt();
pushBlock(thenBlock);
when->body()->traverse(rv);
popBlock();

auto elseBlock = new BlockStmt();
auto branch = new CondStmt(cond, thenBlock, elseBlock);
insertStmt(branch);

pushBlock(elseBlock);
count += 1;
}
}
}

// For every when-stmt, pop the else-conditions off
for (int i = 0; i < count; i++) {
popBlock();
}

return false;
}

void TConverter::exit(const Select* node, RV& rv) {}


bool TConverter::enter(const AstNode* node, RV& rv) {
TC_DEBUGF(this, "enter ast %s %s\n", node->id().str().c_str(), asttags::tagToString(node->tag()));
Expand Down
2 changes: 2 additions & 0 deletions compiler/passes/scopeResolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,8 @@ static void handleForallGoto(ForallStmt* forall, GotoStmt* gs) {

static void resolveGotoLabels() {
forv_Vec(GotoStmt, gs, gGotoStmts) {
if (gs->parentSymbol->hasFlag(FLAG_RESOLVED_EARLY)) continue;

SET_LINENO(gs);

Stmt* loop = NULL;
Expand Down
82 changes: 82 additions & 0 deletions test/frontend/TestSelect.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

use Print;

proc value() {
return 7;
}

proc test(arg: int) {
var x = 6;
select arg {
when 1 do println(1);
when 2 do println(4);
when 3, 4, 5 do println(9);
when x do println(36);
when value() do println(49);
otherwise println(1234);
}
}

proc paramValue() param do return 7;

proc testParam(param p: int) {
var x = 5;
var y = 3;
select p {
when x do println(1234); // make sure values still work
when y, 1, 2 do println(1); // mixture of param and value
when 6 do println(42); // param only
when paramValue() do println(777); // param returned via procedure
otherwise do println(999); // otherwise
}
}

proc helper() {
println(42);
return 5;
}

proc valueRet(arg: int) {
select arg {
when 1 do return 1; // single-expr case
when 2 do return 2; // intentionally do not pass '2' to test AST
when 3, 4, 5 do return 5; // multi-expr case
otherwise do return 42; // otherwise
}
}

proc main() {
test(1);
test(2);
test(3);
test(4);
test(5);
test(6);
test(7);
test(8);
test(9);
test(10);

testParam(1);
testParam(2);
testParam(3);
testParam(4);
testParam(5);
testParam(6);
testParam(7);
testParam(8);
testParam(9);
testParam(10);

// Should see only one '42' printed
select helper() {
when 1 do println(0);
when 2 do println(0);
when 5 do println(5);
otherwise do println(999);
}

println(valueRet(1));
println(valueRet(5));
println(valueRet(99));
}
25 changes: 25 additions & 0 deletions test/frontend/TestSelect.good
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
1
4
9
9
9
36
49
1234
1234
1234
1
1
1
999
1234
42
777
999
999
999
42
5
1
5
42