Skip to content

Commit e1d9bd1

Browse files
committed
Store when-statement case type normally, store bool in AssociatedAction
Previously we were storing the boolean result of a when-statement's '==' operator in the ResolvedExpression for the case-expression. This was inconsistent with how we do things elsewhere, and caused complications in the typed converter. This commit adds a QualifiedType to AssociatedActions and stores the boolean result there. Storing this result is particularly important in the case where it is param-true or param-false. This commit also adds a 'getAction' method to ResolvedExpression to facilitate fetching the 'COMPARE' action. Signed-off-by: Ben Harshbarger <[email protected]>
1 parent 5f832c4 commit e1d9bd1

File tree

6 files changed

+80
-33
lines changed

6 files changed

+80
-33
lines changed

frontend/include/chpl/resolution/resolution-types.h

+25-5
Original file line numberDiff line numberDiff line change
@@ -2378,15 +2378,18 @@ class AssociatedAction {
23782378
Action action_;
23792379
const TypedFnSignature* fn_;
23802380
ID id_;
2381+
types::QualifiedType type_;
23812382

23822383
public:
2383-
AssociatedAction(Action action, const TypedFnSignature* fn, ID id)
2384-
: action_(action), fn_(fn), id_(id) {
2384+
AssociatedAction(Action action, const TypedFnSignature* fn, ID id,
2385+
types::QualifiedType type)
2386+
: action_(action), fn_(fn), id_(id), type_(type) {
23852387
}
23862388
bool operator==(const AssociatedAction& other) const {
23872389
return action_ == other.action_ &&
23882390
fn_ == other.fn_ &&
2389-
id_ == other.id_;
2391+
id_ == other.id_ &&
2392+
type_ == other.type_;
23902393
}
23912394
bool operator!=(const AssociatedAction& other) const {
23922395
return !(*this == other);
@@ -2400,9 +2403,12 @@ class AssociatedAction {
24002403
/** Return the ID is associated with the action */
24012404
const ID& id() const { return id_; }
24022405

2406+
const types::QualifiedType type() const { return type_; }
2407+
24032408
void mark(Context* context) const {
24042409
if (fn_ != nullptr) fn_->mark(context);
24052410
id_.mark(context);
2411+
type_.mark(context);
24062412
}
24072413

24082414
void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const;
@@ -2478,6 +2484,19 @@ class ResolvedExpression {
24782484
return associatedActions_;
24792485
}
24802486

2487+
// TODO: Expected to be a placeholder as we look towards updating the
2488+
// representation of associated actions.
2489+
std::optional<AssociatedAction> getAction(AssociatedAction::Action action) const {
2490+
// TODO: what if there are multiple instances of the same action?
2491+
auto it = std::find_if(associatedActions_.begin(), associatedActions_.end(),
2492+
[&](const AssociatedAction a) { return a.action() == action; });
2493+
if (it != associatedActions_.end()) {
2494+
return *it;
2495+
} else {
2496+
return {};
2497+
}
2498+
}
2499+
24812500
const ResolvedParamLoop* paramLoop() const {
24822501
return paramLoop_;
24832502
}
@@ -2505,8 +2524,9 @@ class ResolvedExpression {
25052524
/** add an associated function */
25062525
void addAssociatedAction(AssociatedAction::Action action,
25072526
const TypedFnSignature* fn,
2508-
ID id) {
2509-
associatedActions_.push_back(AssociatedAction(action, fn, id));
2527+
ID id,
2528+
types::QualifiedType type) {
2529+
associatedActions_.push_back(AssociatedAction(action, fn, id, type));
25102530
}
25112531

25122532
void setParamLoop(const ResolvedParamLoop* paramLoop) { paramLoop_ = paramLoop; }

frontend/lib/resolution/Resolver.cpp

+31-21
Original file line numberDiff line numberDiff line change
@@ -681,11 +681,11 @@ const types::Param* Resolver::determineWhenCaseValue(const uast::AstNode* ast, I
681681
/* hasQuestionArg */ false,
682682
/* isParenless */ false,
683683
actuals);
684-
auto c = resolveGeneratedCall(ast, &ci, &inScopes);
685-
c.noteResult(&caseResult, { { AssociatedAction::COMPARE, ast->id() } });
686684

685+
auto c = resolveGeneratedCall(ast, &ci, &inScopes);
687686
auto type = c.result.exprType();
688-
caseResult.setType(type);
687+
c.noteResult(&caseResult, { { AssociatedAction::COMPARE, ast->id(), type } });
688+
689689
return type.param();
690690
}
691691

@@ -2308,7 +2308,7 @@ bool Resolver::CallResultWrapper::noteResultWithoutError(
23082308
ResolvedExpression* r,
23092309
const uast::AstNode* astForContext,
23102310
const CallResolutionResult& result,
2311-
optional<ActionAndId> actionAndId) {
2311+
optional<ActionInfo> actionInfo) {
23122312
bool needsErrors = false;
23132313
bool markErroneous = false;
23142314

@@ -2339,7 +2339,7 @@ bool Resolver::CallResultWrapper::noteResultWithoutError(
23392339
}
23402340

23412341
if (!result.exprType().hasTypePtr() || markErroneous) {
2342-
if (!actionAndId && r) {
2342+
if (!actionInfo && r) {
23432343
// Only set the type to erroneous if we're handling an actual user call,
23442344
// and not an associated action.
23452345
r->setType(QualifiedType(r->type().kind(), ErroneousType::get(resolver.context)));
@@ -2350,12 +2350,22 @@ bool Resolver::CallResultWrapper::noteResultWithoutError(
23502350
// issued its own error, so we shouldn't emit a general error.
23512351
return !result.speciallyHandled() || needsErrors;
23522352
} else {
2353-
if (actionAndId) {
2353+
if (actionInfo) {
23542354
// save candidates as associated functions
2355-
for (auto& sig : result.mostSpecific()) {
2356-
if (sig && r) {
2357-
r->addAssociatedAction(std::get<0>(*actionAndId), sig.fn(),
2358-
std::get<1>(*actionAndId));
2355+
if (result.mostSpecific().isEmpty() && r) {
2356+
// Store an associated action that did not have a function.
2357+
// E.g., 'COMPARE' for a when-statement condition that is param-true
2358+
r->addAssociatedAction(actionInfo->action, nullptr,
2359+
actionInfo->id, actionInfo->type);
2360+
} else {
2361+
bool seen = false;
2362+
for (auto& sig : result.mostSpecific()) {
2363+
if (sig && r) {
2364+
CHPL_ASSERT(!seen);
2365+
seen = true;
2366+
r->addAssociatedAction(actionInfo->action, sig.fn(),
2367+
actionInfo->id, actionInfo->type);
2368+
}
23592369
}
23602370
}
23612371
} else if (r) {
@@ -2372,13 +2382,13 @@ bool Resolver::CallResultWrapper::noteResultWithoutError(
23722382

23732383

23742384
bool Resolver::CallResultWrapper::noteResultWithoutError(ResolvedExpression* r,
2375-
optional<ActionAndId> actionAndId) {
2376-
return noteResultWithoutError(*parent, r, astForContext, result, std::move(actionAndId));
2385+
optional<ActionInfo> actionInfo) {
2386+
return noteResultWithoutError(*parent, r, astForContext, result, std::move(actionInfo));
23772387
}
23782388

23792389
void Resolver::CallResultWrapper::noteResult(ResolvedExpression* r,
2380-
optional<ActionAndId> actionAndId) {
2381-
if (noteResultWithoutError(r, std::move(actionAndId))) {
2390+
optional<ActionInfo> actionInfo) {
2391+
if (noteResultWithoutError(r, std::move(actionInfo))) {
23822392
issueBasicError();
23832393
}
23842394
}
@@ -2414,9 +2424,9 @@ bool Resolver::CallResultWrapper::rerunCallAndPrintCandidates() {
24142424
}
24152425

24162426
void Resolver::CallResultWrapper::noteResultPrintCandidates(ResolvedExpression* r,
2417-
optional<ActionAndId> actionAndId) {
2427+
optional<ActionInfo> actionInfo) {
24182428
CHPL_ASSERT(!wasGeneratedCall || receiverType.isUnknown());
2419-
if (noteResultWithoutError(r, std::move(actionAndId))) {
2429+
if (noteResultWithoutError(r, std::move(actionInfo))) {
24202430
if (rerunCallAndPrintCandidates()) {
24212431
return;
24222432
}
@@ -2792,7 +2802,7 @@ bool Resolver::resolveSpecialNewCall(const Call* call) {
27922802

27932803
// note: the resolution machinery will get compiler generated candidates
27942804
auto c = resolveGeneratedCall(call, &ci, &inScopes);
2795-
optional<ActionAndId> action({ AssociatedAction::NEW_INIT, call->id() });
2805+
optional<ActionInfo> action({ AssociatedAction::NEW_INIT, call->id() });
27962806
c.noteResultPrintCandidates(&re, std::move(action));
27972807

27982808

@@ -3221,7 +3231,7 @@ bool Resolver::resolveSpecialKeywordCall(const Call* call) {
32213231

32223232
// Note: this issues errors from compilerError in the body of the
32233233
// domain builder.
3224-
optional<ActionAndId> action({ AssociatedAction::RUNTIME_TYPE, fnCall->id() });
3234+
optional<ActionInfo> action({ AssociatedAction::RUNTIME_TYPE, fnCall->id() });
32253235
bool needsErrors =
32263236
runResult.result().noteResultWithoutError(&r, std::move(action));
32273237

@@ -4717,8 +4727,8 @@ bool Resolver::enter(const uast::Manage* manage) {
47174727
}
47184728
}
47194729
CHPL_ASSERT(enterSig && exitSig);
4720-
rr.addAssociatedAction(AssociatedAction::ENTER_CONTEXT, enterSig, manage->id());
4721-
rr.addAssociatedAction(AssociatedAction::EXIT_CONTEXT, exitSig, manage->id());
4730+
rr.addAssociatedAction(AssociatedAction::ENTER_CONTEXT, enterSig, manage->id(), {});
4731+
rr.addAssociatedAction(AssociatedAction::EXIT_CONTEXT, exitSig, manage->id(), {});
47224732
}
47234733

47244734
enterScope(manage);
@@ -5228,7 +5238,7 @@ rerunCallInfoWithIteratorTag(ResolutionContext* rc,
52285238
if (!newC.mostSpecific().isEmpty()) {
52295239
for (auto sig : newC.mostSpecific()) {
52305240
if (!sig) continue;
5231-
r.addAssociatedAction(AssociatedAction::ITERATE, sig.fn(), call->id());
5241+
r.addAssociatedAction(AssociatedAction::ITERATE, sig.fn(), call->id(), {});
52325242
}
52335243

52345244
return newC;

frontend/lib/resolution/Resolver.h

+11-5
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ namespace resolution {
3737

3838
struct Resolver : BranchSensitiveVisitor<DefaultFrame> {
3939
// types used below
40-
using ActionAndId = std::tuple<AssociatedAction::Action, ID>;
40+
//using ActionInfo = std::tuple<AssociatedAction::Action, ID, types::QualifiedType>;
41+
struct ActionInfo {
42+
public:
43+
AssociatedAction::Action action;
44+
ID id;
45+
types::QualifiedType type;
46+
};
4147
using SubstitutionsMap = types::CompositeType::SubstitutionsMap;
4248
using ReceiverScopesVec = SimpleMethodLookupHelper::ReceiverScopesVec;
4349
using IgnoredExtraData = std::variant<std::monostate>;
@@ -527,17 +533,17 @@ struct Resolver : BranchSensitiveVisitor<DefaultFrame> {
527533
//
528534
// Instead, returns 'true' if an error needs to be issued.
529535
bool noteResultWithoutError(ResolvedExpression* r,
530-
optional<ActionAndId> associatedActionAndId = {});
536+
optional<ActionInfo> associatedActionInfo = {});
531537

532538
static bool noteResultWithoutError(Resolver& resolver,
533539
ResolvedExpression* r,
534540
const uast::AstNode* astForContext,
535541
const CallResolutionResult& c,
536-
optional<ActionAndId> associatedActionAndId = {});
542+
optional<ActionInfo> associatedActionInfo = {});
537543

538544
// Same as noteResultWithoutError, but also issues errors.
539545
void noteResult(ResolvedExpression* r,
540-
optional<ActionAndId> associatedActionAndId = {});
546+
optional<ActionInfo> associatedActionInfo = {});
541547

542548
// Issues a more specific error (listing rejected candidates) if possible.
543549
// To collect the candidates, re-runs the call. Returns true if an error
@@ -547,7 +553,7 @@ struct Resolver : BranchSensitiveVisitor<DefaultFrame> {
547553
// Like noteResult, except attempts to do more work to print fancier errors
548554
// (see rerunCallAndPrintCandidates).
549555
void noteResultPrintCandidates(ResolvedExpression* r,
550-
optional<ActionAndId> associatedActionAndId = {});
556+
optional<ActionInfo> associatedActionInfo = {});
551557
};
552558

553559
/* The resolver's wrapper of resolution::resolveGeneratedCall.

frontend/lib/resolution/VarScopeVisitor.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,10 @@ static const types::Param* determineParamValue(const ResolvedExpression& rr) {
211211
}
212212

213213
const types::Param* VarScopeVisitor::determineWhenCaseValue(const uast::AstNode* ast, RV& extraData) {
214-
return determineParamValue(extraData.byAst(ast));
214+
if (auto action = extraData.byAst(ast).getAction(AssociatedAction::COMPARE)) {
215+
return action->type().param();
216+
}
217+
return nullptr;
215218
}
216219

217220
const types::Param* VarScopeVisitor::determineIfValue(const uast::AstNode* ast, RV& extraData) {

frontend/lib/resolution/resolution-types.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,10 @@ void AssociatedAction::stringify(std::ostream& ss,
13641364
ss << " id=";
13651365
id_.stringify(ss, stringKind);
13661366
}
1367+
if (!type_.isUnknown()) {
1368+
ss << " type=";
1369+
type_.stringify(ss, stringKind);
1370+
}
13671371
}
13681372

13691373
void ResolvedExpression::stringify(std::ostream& ss,

frontend/lib/resolution/return-type-inference.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,11 @@ void ReturnTypeInferrer::doExitScope(const uast::AstNode* node, RV& rv) {
499499
}
500500

501501
const types::Param* ReturnTypeInferrer::determineWhenCaseValue(const uast::AstNode* ast, RV& rv) {
502-
return rv.byAst(ast).type().param();
502+
if (auto action = rv.byAst(ast).getAction(AssociatedAction::COMPARE)) {
503+
return action->type().param();
504+
} else {
505+
return nullptr;
506+
}
503507
}
504508
const types::Param* ReturnTypeInferrer::determineIfValue(const uast::AstNode* ast, RV& rv) {
505509
return rv.byAst(ast).type().param();

0 commit comments

Comments
 (0)