Skip to content

Implement &&= and ||= in ChapelBase via new primitive #27102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 71 additions & 53 deletions compiler/AST/TransformLogicalShortCircuit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,66 +31,84 @@
#include "IfExpr.h"
#include "stmt.h"

static size_t opLen = strlen("&&");

bool TransformLogicalShortCircuit::enterCallExpr(CallExpr* call)
{
if (shouldTransformCall(call)) {
UnresolvedSymExpr* expr = toUnresolvedSymExpr(call->baseExpr);

bool isLogicalAnd = strncmp(expr->unresolved, "&&", opLen) == 0;
bool isLogicalOr = strncmp(expr->unresolved, "||", opLen) == 0;
INT_ASSERT(isLogicalAnd || isLogicalOr);
bool isCompoundAssign = strncmp(expr->unresolved + opLen, "=", 1) == 0;

SET_LINENO(call);

Expr* left = call->get(1);
Expr* right = call->get(2);
VarSymbol* lvar = newTemp();

VarSymbol* eMsg = NULL;
IfExpr* ife = NULL;

left->remove();
right->remove();

lvar->addFlag(FLAG_MAYBE_PARAM);

if (isLogicalAnd) {
eMsg = new_StringSymbol("cannot promote short-circuiting && operator");
ife = new IfExpr(new CallExpr("isTrue", lvar),
new CallExpr("isTrue", right), new SymExpr(gFalse));
} else {
eMsg = new_StringSymbol("cannot promote short-circuiting || operator");
ife = new IfExpr(new CallExpr("isTrue", lvar), new SymExpr(gTrue),
new CallExpr("isTrue", right));
}

//
// By handling conditionals in pre-order, we do not need to store an
// insertion point. The top-level conditional will be inserted before
// the original statement, and any nested conditionals will be stored
// within the IfExprs blocks which are still before the original
// statement.
//
Expr* stmt = call->getStmtExpr();
stmt->insertBefore(new DefExpr(lvar));
if (isCompoundAssign) {
stmt->insertBefore(
new CallExpr(PRIM_MOVE, lvar, new CallExpr(PRIM_ADDR_OF, left)));
} else {
stmt->insertBefore(new CallExpr(PRIM_MOVE, lvar, left));
}
stmt->insertBefore(new CondStmt(new CallExpr("_cond_invalid", lvar),
new CallExpr("compilerError", eMsg)));
if (isCompoundAssign) {
stmt->insertAfter(new CallExpr("=", lvar, ife));
call->replace(new SymExpr(lvar));
} else {
call->replace(ife);
}

left->accept(this);
ife->accept(this);
}
return true;
}

bool TransformLogicalShortCircuit::shouldTransformCall(CallExpr* call)
{
// Lowering of LoopExprs will handle short-circuits itself
if (call->primitive == 0 && isLoopExpr(call->parentExpr) == false)
{
if (UnresolvedSymExpr* expr = toUnresolvedSymExpr(call->baseExpr))
{
bool isLogicalAnd = strcmp(expr->unresolved, "&&") == 0;
bool isLogicalOr = strcmp(expr->unresolved, "||") == 0;

if (isLogicalAnd || isLogicalOr)
{
SET_LINENO(call);

Expr* left = call->get(1);
Expr* right = call->get(2);
VarSymbol* lvar = newTemp();

VarSymbol* eMsg = NULL;
IfExpr* ife = NULL;

left->remove();
right->remove();

lvar->addFlag(FLAG_MAYBE_PARAM);

if (isLogicalAnd)
{
eMsg = new_StringSymbol("cannot promote short-circuiting && operator");
ife = new IfExpr(new CallExpr("isTrue", lvar),
new CallExpr("isTrue", right),
new SymExpr(gFalse));
}
else
{
eMsg = new_StringSymbol("cannot promote short-circuiting || operator");
ife = new IfExpr(new CallExpr("isTrue", lvar),
new SymExpr(gTrue),
new CallExpr("isTrue", right));
}

//
// By handling conditionals in pre-order, we do not need to store an
// insertion point. The top-level conditional will be inserted before
// the original statement, and any nested conditionals will be stored
// within the IfExprs blocks which are still before the original
// statement.
//
Expr* stmt = call->getStmtExpr();
stmt->insertBefore(new DefExpr(lvar));
stmt->insertBefore(new CallExpr(PRIM_MOVE, lvar, left));
stmt->insertBefore(new CondStmt(new CallExpr("_cond_invalid", lvar),
new CallExpr("compilerError", eMsg)));

call->replace(ife);

left->accept(this);
ife->accept(this);
}
bool isLogicalAnd = strncmp(expr->unresolved, "&&", opLen) == 0;
bool isLogicalOr = strncmp(expr->unresolved, "||", opLen) == 0;

return isLogicalAnd || isLogicalOr;
}
}
return true;
return false;
}
23 changes: 13 additions & 10 deletions compiler/AST/astutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,16 +534,19 @@ void addUse(Map<Symbol*,Vec<SymExpr*>*>& useMap, SymExpr* use) {
// op= function call (such as before inlining)
//
bool isOpEqualPrim(CallExpr* call) {
if (call->isPrimitive(PRIM_ADD_ASSIGN) ||
call->isPrimitive(PRIM_SUBTRACT_ASSIGN) ||
call->isPrimitive(PRIM_MULT_ASSIGN) ||
call->isPrimitive(PRIM_DIV_ASSIGN) ||
call->isPrimitive(PRIM_MOD_ASSIGN) ||
call->isPrimitive(PRIM_LSH_ASSIGN) ||
call->isPrimitive(PRIM_RSH_ASSIGN) ||
call->isPrimitive(PRIM_AND_ASSIGN) ||
call->isPrimitive(PRIM_OR_ASSIGN) ||
call->isPrimitive(PRIM_XOR_ASSIGN)) {
if (call->isPrimitive(PRIM_ADD_ASSIGN) ||
call->isPrimitive(PRIM_SUBTRACT_ASSIGN) ||
call->isPrimitive(PRIM_MULT_ASSIGN) ||
call->isPrimitive(PRIM_DIV_ASSIGN) ||
call->isPrimitive(PRIM_MOD_ASSIGN) ||
call->isPrimitive(PRIM_LSH_ASSIGN) ||
call->isPrimitive(PRIM_RSH_ASSIGN) ||
call->isPrimitive(PRIM_AND_ASSIGN) ||
call->isPrimitive(PRIM_OR_ASSIGN) ||
call->isPrimitive(PRIM_XOR_ASSIGN) ||
call->isPrimitive(PRIM_LOGICALAND_ASSIGN) ||
call->isPrimitive(PRIM_LOGICALOR_ASSIGN)
) {
return true;
}
//otherwise false
Expand Down
2 changes: 2 additions & 0 deletions compiler/AST/checkAST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ void checkPrimitives()
case PRIM_AND_ASSIGN:
case PRIM_OR_ASSIGN:
case PRIM_XOR_ASSIGN:
case PRIM_LOGICALAND_ASSIGN:
case PRIM_LOGICALOR_ASSIGN:
case PRIM_MIN:
case PRIM_MAX:
case PRIM_SETCID:
Expand Down
2 changes: 2 additions & 0 deletions compiler/AST/primitive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ initPrimitive() {
prim_def(PRIM_AND_ASSIGN, "&=", returnInfoVoid, true);
prim_def(PRIM_OR_ASSIGN, "|=", returnInfoVoid, true);
prim_def(PRIM_XOR_ASSIGN, "^=", returnInfoVoid, true);
prim_def(PRIM_LOGICALAND_ASSIGN, "&&=", returnInfoVoid, true);
prim_def(PRIM_LOGICALOR_ASSIGN, "||=", returnInfoVoid, true);
prim_def(PRIM_REDUCE_ASSIGN, "reduce=", returnInfoVoid, true);

prim_def(PRIM_MIN, "_min", returnInfoFirst);
Expand Down
6 changes: 6 additions & 0 deletions compiler/codegen/cg-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5235,6 +5235,12 @@ DEFINE_PRIM(OR_ASSIGN) {
DEFINE_PRIM(XOR_ASSIGN) {
codegenOpAssign(call->get(1), call->get(2), " ^= ", codegenXor);
}
DEFINE_PRIM(LOGICALAND_ASSIGN) {
codegenOpAssign(call->get(1), call->get(2), " &&= ", codegenLogicalAnd);
}
DEFINE_PRIM(LOGICALOR_ASSIGN) {
codegenOpAssign(call->get(1), call->get(2), " ||= ", codegenLogicalOr);
}
DEFINE_PRIM(POW) {
ret = codegenCallExpr("pow", call->get(1), call->get(2));
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/include/TransformLogicalShortCircuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class TransformLogicalShortCircuit final : public AstVisitorTraverse

// Transform performed pre-order
bool enterCallExpr (CallExpr* node) override;

static bool shouldTransformCall(CallExpr* node);
};

#endif
2 changes: 2 additions & 0 deletions compiler/optimizations/copyPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ static bool isUse(SymExpr* se)
case PRIM_AND_ASSIGN:
case PRIM_OR_ASSIGN:
case PRIM_XOR_ASSIGN:
case PRIM_LOGICALAND_ASSIGN:
case PRIM_LOGICALOR_ASSIGN:
if (isFirstActual)
{
return false;
Expand Down
2 changes: 2 additions & 0 deletions compiler/optimizations/loopInvariantCodeMotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ static bool isLoopInvariantPrimitive(PrimitiveOp* primitiveOp)
case PRIM_AND_ASSIGN:
case PRIM_OR_ASSIGN:
case PRIM_XOR_ASSIGN:
case PRIM_LOGICALAND_ASSIGN:
case PRIM_LOGICALOR_ASSIGN:

case PRIM_MIN:
case PRIM_MAX:
Expand Down
2 changes: 2 additions & 0 deletions compiler/optimizations/optimizeOnClauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ classifyPrimitive(CallExpr *call) {
case PRIM_AND_ASSIGN:
case PRIM_OR_ASSIGN:
case PRIM_XOR_ASSIGN:
case PRIM_LOGICALAND_ASSIGN:
case PRIM_LOGICALOR_ASSIGN:
if (isCallExpr(call->get(2))) { // callExprs checked in calling function
// Not necessarily true, but we return true because
// the callExpr will be checked in the calling function
Expand Down
26 changes: 0 additions & 26 deletions compiler/passes/convert-uast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2156,28 +2156,6 @@ struct Converter final : UastConverter {
return new CallExpr(PRIM_TO_NILABLE_CLASS_CHECKED, expr);
}

Expr* convertLogicalAndAssign(const uast::OpCall* node) {
if (node->op() != USTR("&&=")) return nullptr;

astlocMarker markAstLoc(node->id());

INT_ASSERT(node->numActuals() == 2);
Expr* lhs = convertAST(node->actual(0));
Expr* rhs = convertAST(node->actual(1));
return buildLAndAssignment(lhs, rhs);
}

Expr* convertLogicalOrAssign(const uast::OpCall* node) {
if (node->op() != USTR("||=")) return nullptr;

astlocMarker markAstLoc(node->id());

INT_ASSERT(node->numActuals() == 2);
Expr* lhs = convertAST(node->actual(0));
Expr* rhs = convertAST(node->actual(1));
return buildLOrAssignment(lhs, rhs);
}

Expr* convertTupleAssign(const uast::OpCall* node) {
if (node->op() != USTR("=") || node->numActuals() < 1
|| !node->actual(0)->isTuple()) return nullptr;
Expand Down Expand Up @@ -2226,10 +2204,6 @@ struct Converter final : UastConverter {
ret = conv;
} else if (auto conv = convertToNilableChecked(node)) {
ret = conv;
} else if (auto conv = convertLogicalAndAssign(node)) {
ret = conv;
} else if (auto conv = convertLogicalOrAssign(node)) {
ret = conv;
} else if (auto conv = convertTupleAssign(node)) {
ret = conv;
} else if (node->op() == USTR("align")) {
Expand Down
12 changes: 4 additions & 8 deletions compiler/passes/normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,14 +675,10 @@ static void transformLogicalShortCircuit() {

// Collect the distinct stmts that contain logical AND/OR expressions
for_alive_in_Vec(CallExpr, call, gCallExprs) {
if (call->primitive == 0) {
if (UnresolvedSymExpr* expr = toUnresolvedSymExpr(call->baseExpr)) {
if (strcmp(expr->unresolved, "&&") == 0 ||
strcmp(expr->unresolved, "||") == 0) {
// Don't normalize lifetime constraint clauses
if (isInLifetimeClause(call) == false)
stmts.insert(call->getStmtExpr());
}
if (TransformLogicalShortCircuit::shouldTransformCall(call)) {
// Don't normalize lifetime constraint clauses
if (isInLifetimeClause(call) == false) {
stmts.insert(call->getStmtExpr());
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions frontend/include/chpl/uast/prim-ops-list.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ PRIMITIVE_G(RSH_ASSIGN, ">>=")
PRIMITIVE_G(AND_ASSIGN, "&=")
PRIMITIVE_G(OR_ASSIGN, "|=")
PRIMITIVE_G(XOR_ASSIGN, "^=")
PRIMITIVE_G(LOGICALAND_ASSIGN, "&&=")
PRIMITIVE_G(LOGICALOR_ASSIGN, "||=")
PRIMITIVE_R(REDUCE_ASSIGN, "reduce=")

PRIMITIVE_G(MIN, "_min")
Expand Down
Loading
Loading