Skip to content

Commit d3fb96b

Browse files
committed
feat: maintain program counter and assert at beginning of blocks
this makes a fairly significant change to the GTIRB loader by inserting PC increment statements after instructions and checking the PC at the entry of blocks. this is done by changing the loader to maintain PC and branchtaken assignments, and inserting a default increment if there is no branch. this surely has bugs
1 parent 4ba88e6 commit d3fb96b

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

src/main/scala/translating/GTIRBLoader.scala

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,18 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
2626
private var blockCount = 0
2727
private var loadCounter = 0
2828

29-
private val opcodeSize = 4
29+
val opcodeSize = 4
3030

31-
def visitBlock(blockUUID: ByteString, blockCountIn: Int, blockAddress: Option[BigInt]): ArrayBuffer[Statement] = {
31+
def visitBlock(
32+
blockUUID: ByteString,
33+
blockCountIn: Int,
34+
blockAddress: Option[BigInt]
35+
): ArrayBuffer[immutable.Seq[Statement]] = {
3236
blockCount = blockCountIn
3337
instructionCount = 0
3438
val instructions = parserMap(Base64.getEncoder.encodeToString(blockUUID.toByteArray))
3539

36-
val statements: ArrayBuffer[Statement] = ArrayBuffer()
40+
val statements: ArrayBuffer[immutable.Seq[Statement]] = ArrayBuffer()
3741

3842
for (instsem <- instructions) {
3943
constMap.clear
@@ -43,17 +47,19 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
4347
case InsnSemantics.Error(op, err) => {
4448
val message = s"$op ${err.replace("\n", " :: ")}"
4549
Logger.warn(s"Program contains lifter unsupported opcode: $message")
46-
statements.append(Assert(FalseLiteral, Some(s"Lifter error: $message")))
50+
statements.append(immutable.Seq(Assert(FalseLiteral, Some(s"Lifter error: $message"))))
4751
instructionCount += 1
4852
}
49-
case InsnSemantics.Result(instruction) => {
50-
for ((s, i) <- instruction.zipWithIndex) {
53+
case InsnSemantics.Result(aslstmts) => {
54+
var stmts = immutable.LinearSeq[Statement]()
55+
56+
for ((s, i) <- aslstmts.zipWithIndex) {
5157
val label = blockAddress.map { (a: BigInt) =>
5258
val instructionAddress = a + (opcodeSize * instructionCount)
5359
instructionAddress.toString + "_" + i
5460
}
5561

56-
statements.appendAll(try {
62+
stmts = stmts ++ (try {
5763
visitStmt(s, label)
5864
} catch {
5965
case e => {
@@ -62,6 +68,8 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
6268
}
6369
})
6470
}
71+
72+
statements.append(stmts)
6573
instructionCount += 1
6674
}
6775
}
@@ -176,7 +184,7 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
176184
}
177185

178186
if (condition.isDefined) {
179-
Some(TempIf(condition.get, thenStmts, elseStmts, label))
187+
Some(TempIf(condition.get, thenStmts.to(immutable.LinearSeq), elseStmts.to(immutable.LinearSeq), label))
180188
} else {
181189
None
182190
}
@@ -298,7 +306,7 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
298306
case "FALSE" => Some(FalseLiteral)
299307
case "FPCR" => Some(Register("FPCR", 32))
300308
// ignore the following
301-
case "__BranchTaken" => None
309+
case "__BranchTaken" => Some(LocalVar("__BranchTaken", BoolType))
302310
case "BTypeNext" => None
303311
case "BTypeCompatible" => None
304312
case "TPIDR_EL0" => Some(Register(name, 64))
@@ -682,7 +690,7 @@ class GTIRBLoader(parserMap: immutable.Map[String, List[InsnSemantics]]) {
682690
// ignore the following
683691
case "TRUE" => throw Exception(s"Boolean literal $name in LExpr ${ctx.getText}")
684692
case "FALSE" => throw Exception(s"Boolean literal $name in LExpr ${ctx.getText}")
685-
case "__BranchTaken" => None
693+
case "__BranchTaken" => Some(LocalVar("__BranchTaken", BoolType))
686694
case "BTypeNext" => None
687695
case "BTypeCompatible" => None
688696
case "TPIDR_EL0" => Some(Register(name, 64))

src/main/scala/translating/GTIRBToIR.scala

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import scala.collection.mutable.Map
1717
import scala.collection.mutable.ArrayBuffer
1818
import scala.collection.immutable
1919
import scala.jdk.CollectionConverters.*
20+
import scala.util.{Try, Success, Failure}
2021
import java.util.Base64
2122
import java.nio.charset.*
2223
import scala.util.boundary
@@ -25,6 +26,13 @@ import java.nio.ByteBuffer
2526
import util.intrusive_list.*
2627
import util.Logger
2728

29+
private def assigned(x: Statement): immutable.Set[Variable] = x match {
30+
case x: Assign => x.assignees
31+
case x: TempIf =>
32+
x.cond.variables ++ x.thenStmts.flatMap(assigned) ++ x.elseStmts.flatMap(assigned)
33+
case _ => immutable.Set.empty
34+
}
35+
2836
/** TempIf class, used to temporarily store information about Jumps so that multiple parse runs are not needed.
2937
* Specifically, this is useful in the case that the IF statment has multiple conditions( and elses) and as such many
3038
* extra blocks need to be created.
@@ -38,8 +46,8 @@ import util.Logger
3846
*/
3947
class TempIf(
4048
val cond: Expr,
41-
val thenStmts: mutable.Buffer[Statement],
42-
val elseStmts: mutable.Buffer[Statement],
49+
val thenStmts: immutable.Seq[Statement],
50+
val elseStmts: immutable.Seq[Statement],
4351
override val label: Option[String] = None
4452
) extends NOP(label)
4553

@@ -177,7 +185,9 @@ class GTIRBToIR(
177185

178186
val statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address)
179187
blockCount += 1
180-
block.statements.addAll(statements)
188+
for ((stmts, i) <- statements.zipWithIndex) {
189+
block.statements.addAll(insertPCIncrement(stmts))
190+
}
181191

182192
if (block.statements.isEmpty && !blockOutgoingEdges.contains(blockUUID)) {
183193
// remove blocks that are just nop padding
@@ -236,11 +246,33 @@ class GTIRBToIR(
236246
Program(procedures, intialProc, initialMemory)
237247
}
238248

249+
private def insertPCIncrement(isnStmts: immutable.Seq[Statement]): immutable.Seq[Statement] = {
250+
Try(isnStmts.last) match {
251+
case Success(x: TempIf) =>
252+
isnStmts.init :+ TempIf(x.cond, insertPCIncrement(x.thenStmts), insertPCIncrement(x.elseStmts))
253+
case Success(_) | Failure(_) => {
254+
val branchTaken = isnStmts.exists {
255+
case LocalAssign(LocalVar("__BranchTaken", BoolType, _), TrueLiteral, _) => true
256+
case _: TempIf => throw Exception("encountered TempIf not at end of statement list: " + isnStmts)
257+
case _ => false
258+
}
259+
val increment =
260+
if branchTaken then Seq()
261+
else
262+
Seq(
263+
LocalAssign(Register("_PC", 64), BinaryExpr(BVADD, Register("_PC", 64), BitVecLiteral(4, 64)), None),
264+
LocalAssign(LocalVar("__BranchTaken", BoolType), FalseLiteral)
265+
)
266+
increment ++: isnStmts
267+
}
268+
}
269+
}
270+
239271
private def removePCAssign(block: Block): Option[String] = {
240272
block.statements.last match {
241273
case last @ LocalAssign(lhs: Register, _, _) if lhs.name == "_PC" =>
242274
val label = last.label
243-
block.statements.remove(last)
275+
// block.statements.remove(last)
244276
label
245277
case _ => throw Exception(s"expected block ${block.label} to have a program counter assignment at its end")
246278
}
@@ -309,9 +341,20 @@ class GTIRBToIR(
309341
throw Exception(s"block ${byteStringToString(blockUUID)} is in multiple functions")
310342
}
311343
uuidToBlock += (blockUUID -> block)
312-
if (blockUUID == entranceUUID) {
344+
val isEntrance = blockUUID == entranceUUID
345+
if (isEntrance) {
313346
procedure.entryBlock = block
314347
}
348+
val checkPCStmt: Expr => Statement =
349+
if isEntrance
350+
then Assume(_, None, None, false)
351+
else Assert(_, None, None)
352+
block.address match {
353+
case Some(addr) =>
354+
val assertPC = checkPCStmt(BinaryExpr(BVEQ, Register("_PC", 64), BitVecLiteral(addr, 64)))
355+
block.statements.append(assertPC)
356+
case _ => ()
357+
}
315358
block
316359
}
317360

@@ -781,7 +824,7 @@ class GTIRBToIR(
781824

782825
val newBlocks = ArrayBuffer(trueBlock, falseBlock)
783826
procedure.addBlocks(newBlocks)
784-
block.statements.remove(tempIf)
827+
// block.statements.remove(tempIf)
785828

786829
GoTo(newBlocks)
787830
}

0 commit comments

Comments
 (0)