Skip to content

simplify: implement standard copyprop #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 272 additions & 14 deletions src/main/scala/ir/transforms/Simp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -694,37 +694,26 @@ def copypropTransform(
val t = util.PerformanceTimer(s"simplify ${p.name} (${p.blocks.size} blocks)")
// SimplifyLogger.info(s"${p.name} ExprComplexity ${ExprComplexity()(p)}")
// val result = solver.solveProc(p, true).withDefaultValue(dom.bot)
val result = CopyProp.DSACopyProp(p, procFrames, funcEntries, constRead)
val solve = t.checkPoint("Solve CopyProp")

if (result.nonEmpty) {
val vis = Simplify(CopyProp.toResult(result))
visit_proc(vis, p)
}
AlgebraicSimplifications(p)
OffsetProp.transform(p)

visit_proc(CopyProp.BlockyProp(), p)
simplifyCFG(p)
transforms.fixupGuards(p)
transforms.removeDuplicateGuard(p.blocks.toSeq)

val gvis = GuardVisitor(ir.eval.SimplifyValidation.validate)
visit_proc(gvis, p)
AssumeConditionSimplifications(p)

val xf = t.checkPoint("transform")
// SimplifyLogger.info(s" ${p.name} after transform expr complexity ${ExprComplexity()(p)}")

visit_proc(CleanupAssignments(), p)
t.checkPoint("redundant assignments")
// SimplifyLogger.info(s" ${p.name} after dead var cleanup expr complexity ${ExprComplexity()(p)}")

AlgebraicSimplifications(p)
AssumeConditionSimplifications(p)

AlgebraicSimplifications(p)
// SimplifyLogger.info(s" ${p.name} after simp expr complexity ${ExprComplexity()(p)}")
val sipm = t.checkPoint("algebraic simp")

// SimplifyLogger.info("[!] Simplify :: RemoveSlices")
removeSlices(p)
ir.eval.cleanupSimplify(p)
AlgebraicSimplifications(p)
Expand Down Expand Up @@ -1172,6 +1161,275 @@ object getProcFrame {

}

object OffsetProp {

/*
* Copyprop for any expression of fitting into the structure
* bvadd(variable, constant)
*
* This is sufficient to propagate branch conditions through.
*/

// None, None -> Top
// Some(v), None -> v
// Some(v), Some(Lit) -> v + Lit
// None, Some(Lit) -> Lit
type Value = (Option[Variable], Option[BitVecLiteral])

def joinValue(l: Value, r: Value) = {
(l, r) match {
case ((None, None), _) => (None, None)
case (_, (None, None)) => (None, None)
case (l, r) if l != r => (None, None)
case (l, r) => l
}
}

class CopyProp() {
val st = mutable.Map[Variable, Value]()
var giveUp = false
val lastUpdate = mutable.Map[Block, Int]()
var stSequenceNo = 1

def findOff(v: Variable, c: BitVecLiteral): BitVecLiteral | Variable | BinaryExpr = find(v) match {
case lc: BitVecLiteral => ir.eval.BitVectorEval.smt_bvadd(lc, c)
case lv: Variable => BinaryExpr(BVADD, lv, c)
case BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) =>
BinaryExpr(BVADD, l, ir.eval.BitVectorEval.smt_bvadd(r, c))
case _ => throw Exception("Unexpected expression structure created by find() at some point")
}

def find(v: Variable): BitVecLiteral | Variable | BinaryExpr = {
st.get(v) match {
case None => v
case Some((None, None)) => v
case Some((None, Some(c))) => c
case Some((Some(v), None)) => find(v)
case Some((Some(v), Some(c))) => findOff(v, c)
}
}

def joinState(lhs: Variable, rhs: Expr) = {
specJoinState(lhs, rhs) match {
case Some((l, r)) => {
if (st.contains(l) && st(l) != r) {
stSequenceNo += 1
}
st(l) = r
}
case _ => ()
}
}

def specJoinState(lhs: Variable, rhs: Expr): Option[(Variable, Value)] = {
rhs match {
case e @ BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) if (!st.contains(lhs)) =>
Some(lhs -> (Some(l), Some(r)))
case e @ BinaryExpr(BVADD, l: Variable, r: BitVecLiteral) if findOff(l, r) == find(lhs) => None
case v: Variable if (!st.contains(lhs)) => Some(lhs -> (Some(v), None))
case v: BitVecLiteral if (!st.contains(lhs)) => Some(lhs -> (None, Some(v)))
case v: Variable if (find(lhs) == find(v)) => None
case c: BitVecLiteral if (find(lhs) != c) => Some(lhs -> (None, None))
case _ => Some(lhs -> (None, None))
}
}

def clob(v: Variable) = {
st(v) = (None, None)
}

def transfer(s: Statement) = s match {
case LocalAssign(l: Variable, r: Variable, _) => joinState(l, r)
case LocalAssign(l: Variable, r: Literal, _) => joinState(l, r)
case LocalAssign(l: Variable, r @ BinaryExpr(BVADD, _: Variable, _: BitVecLiteral), _) => joinState(l, r)
case LocalAssign(l: Variable, _, _) => clob(l)
// case s: SimulAssign => s.assignments.flatMap {
// case (l: Variable, r: Variable) => specJoinState(l, r).toSeq
// case (l: Variable, r) => Seq(l -> None)
// }.foreach {
// case (l, r) => st(l) = r
// }
case a: Assign => {
// memoryload and DirectCall
a.assignees.foreach(clob)
}
case _: MemoryStore => ()
case _: NOP => ()
case _: Assert => ()
case _: Assume => ()
case i: IndirectCall => giveUp = true
}

def analyse(p: Procedure): Map[Variable, Expr] = {
reversePostOrder(p)
val worklist = mutable.PriorityQueue[Block]()(Ordering.by(_.rpoOrder))
worklist.addAll(p.entryBlock)
while (worklist.nonEmpty && !giveUp) {
val b = worklist.dequeue()
val seq = lastUpdate.get(b).getOrElse(0)

b.statements.foreach(transfer)

if (stSequenceNo != seq || seq == 0) {
lastUpdate(b) = stSequenceNo
worklist.addAll(b.nextBlocks)
}
}

val res: Map[Variable, Variable | Literal | BinaryExpr] =
if giveUp then Map()
else
st.collect {
case (v, (Some(v2), None)) => v -> find(v2)
case (v, (None, Some(c))) => v -> c
case (v, (Some(v2), Some(c))) => v -> findOff(v2, c)
}.toMap

res
}

}

def transform(p: Procedure) = {
val solver = CopyProp()
val res = solver.analyse(p)

class SubstExprs(subst: Map[Variable, Expr]) extends CILVisitor {
override def vexpr(e: Expr) = {
Substitute(subst.get)(e) match {
case Some(n) => ChangeTo(n)
case _ => SkipChildren()
}
}
}
if (res.nonEmpty) {
visit_proc(SubstExprs(res), p)
}
}
}

object MinCopyProp {

// None -> Top
type Value = Option[Variable | Literal]

class CopyProp() {
val st = mutable.Map[Variable, Value]()
val lastUpdate = mutable.Map[Block, Int]()
var stSequenceNo = 1
var giveUp = false

def find(v: Variable): Literal | Variable = {
var search: Variable = v

boundary {
while (st.contains(search)) {
st(search) match {
case None => break(search)
case Some(c: Literal) => break(c)
case Some(v: Variable) =>
search = v
}
}
search
}
}

def specJoinState(lhs: Variable, rhs: Variable | Literal): Option[(Variable, Value)] = {
rhs match {
case v: Variable if (!st.contains(lhs)) => Some(lhs -> Some(v))
case v: Literal if (!st.contains(lhs)) => Some(lhs -> Some(v))
case v: Variable if (find(lhs) != find(v)) => Some(lhs -> None)
case c: Literal if (find(lhs) != c) => Some(lhs -> None)
case _ => None
}
}

def joinState(lhs: Variable, rhs: Variable | Literal) = {
specJoinState(lhs, rhs) match {
case Some((l, r)) => {
if (st.contains(l) && st(l) != r) {
stSequenceNo += 1
}
st(l) = r
}
case _ => ()
}
}

def clob(v: Variable) = {
st(v) = None
}

def transfer(s: Statement) = s match {
case LocalAssign(l: Variable, r: Variable, _) => joinState(l, r)
case LocalAssign(l: Variable, r: Literal, _) => joinState(l, r)
case LocalAssign(l: Variable, _, _) => clob(l)
// case s: SimulAssign => s.assignments.flatMap {
// case (l: Variable, r: Variable) => specJoinState(l, r).toSeq
// case (l: Variable, r) => Seq(l -> None)
// }.foreach {
// case (l, r) => st(l) = r
// }
case a: Assign => {
// memoryload and DirectCall
a.assignees.foreach(clob)
}
case _: MemoryStore => ()
case _: NOP => ()
case _: Assert => ()
case _: Assume => ()
case i: IndirectCall => giveUp = true
}

def analyse(p: Procedure): Map[Variable, Variable | Literal] = {
reversePostOrder(p)
val worklist = mutable.PriorityQueue[Block]()(Ordering.by(_.rpoOrder))
worklist.addAll(p.entryBlock)
while (worklist.nonEmpty && !giveUp) {
val b = worklist.dequeue()
val seq = lastUpdate.get(b).getOrElse(0)

b.statements.foreach(transfer)

if (stSequenceNo != seq || seq == 0) {
lastUpdate(b) = stSequenceNo
worklist.addAll(b.nextBlocks)
}
}

val res: Map[Variable, Variable | Literal] =
if giveUp then Map()
else
st.collect {
case (v, Some(r: Variable)) => v -> find(r)
case (v, Some(c: Literal)) => v -> c
}.toMap

res
}
}

def transform(p: Procedure) = {
val solver = CopyProp()
val res = solver.analyse(p)

class SubstExprs(subst: Map[Variable, Expr]) extends CILVisitor {
override def vexpr(e: Expr) = {
Substitute(subst.get)(e) match {
case Some(n) => ChangeTo(n)
case _ => SkipChildren()
}
}
}
if (res.nonEmpty) {
visit_proc(SubstExprs(res), p)
}

}

}

object CopyProp {

class BlockyProp(trivialOnly: Boolean = true, var transform: Boolean = true) extends CILVisitor {
Expand Down
Loading