@@ -17,6 +17,7 @@ import scala.collection.mutable.Map
17
17
import scala .collection .mutable .ArrayBuffer
18
18
import scala .collection .immutable
19
19
import scala .jdk .CollectionConverters .*
20
+ import scala .util .{Try , Success , Failure }
20
21
import java .util .Base64
21
22
import java .nio .charset .*
22
23
import scala .util .boundary
@@ -25,6 +26,13 @@ import java.nio.ByteBuffer
25
26
import util .intrusive_list .*
26
27
import util .Logger
27
28
29
+ private def assigned (x : Statement ): immutable.Set [Variable ] = x match {
30
+ case x : Assign => x.assignees
31
+ case x : TempIf =>
32
+ x.cond.variables ++ x.thenStmts.flatMap(assigned) ++ x.elseStmts.flatMap(assigned)
33
+ case _ => immutable.Set .empty
34
+ }
35
+
28
36
/** TempIf class, used to temporarily store information about Jumps so that multiple parse runs are not needed.
29
37
* Specifically, this is useful in the case that the IF statment has multiple conditions( and elses) and as such many
30
38
* extra blocks need to be created.
@@ -38,8 +46,8 @@ import util.Logger
38
46
*/
39
47
class TempIf (
40
48
val cond : Expr ,
41
- val thenStmts : mutable. Buffer [Statement ],
42
- val elseStmts : mutable. Buffer [Statement ],
49
+ val thenStmts : immutable. Seq [Statement ],
50
+ val elseStmts : immutable. Seq [Statement ],
43
51
override val label : Option [String ] = None
44
52
) extends NOP (label)
45
53
@@ -177,7 +185,9 @@ class GTIRBToIR(
177
185
178
186
val statements = semanticsLoader.visitBlock(blockUUID, blockCount, block.address)
179
187
blockCount += 1
180
- block.statements.addAll(statements)
188
+ for ((stmts, i) <- statements.zipWithIndex) {
189
+ block.statements.addAll(insertPCIncrement(stmts))
190
+ }
181
191
182
192
if (block.statements.isEmpty && ! blockOutgoingEdges.contains(blockUUID)) {
183
193
// remove blocks that are just nop padding
@@ -236,11 +246,33 @@ class GTIRBToIR(
236
246
Program (procedures, intialProc, initialMemory)
237
247
}
238
248
249
+ private def insertPCIncrement (isnStmts : immutable.Seq [Statement ]): immutable.Seq [Statement ] = {
250
+ Try (isnStmts.last) match {
251
+ case Success (x : TempIf ) =>
252
+ isnStmts.init :+ TempIf (x.cond, insertPCIncrement(x.thenStmts), insertPCIncrement(x.elseStmts))
253
+ case Success (_) | Failure (_) => {
254
+ val branchTaken = isnStmts.exists {
255
+ case LocalAssign (LocalVar (" __BranchTaken" , BoolType , _), TrueLiteral , _) => true
256
+ case _ : TempIf => throw Exception (" encountered TempIf not at end of statement list: " + isnStmts)
257
+ case _ => false
258
+ }
259
+ val increment =
260
+ if branchTaken then Seq ()
261
+ else
262
+ Seq (
263
+ LocalAssign (Register (" _PC" , 64 ), BinaryExpr (BVADD , Register (" _PC" , 64 ), BitVecLiteral (4 , 64 )), None ),
264
+ LocalAssign (LocalVar (" __BranchTaken" , BoolType ), FalseLiteral )
265
+ )
266
+ increment ++: isnStmts
267
+ }
268
+ }
269
+ }
270
+
239
271
private def removePCAssign (block : Block ): Option [String ] = {
240
272
block.statements.last match {
241
273
case last @ LocalAssign (lhs : Register , _, _) if lhs.name == " _PC" =>
242
274
val label = last.label
243
- block.statements.remove(last)
275
+ // block.statements.remove(last)
244
276
label
245
277
case _ => throw Exception (s " expected block ${block.label} to have a program counter assignment at its end " )
246
278
}
@@ -309,9 +341,20 @@ class GTIRBToIR(
309
341
throw Exception (s " block ${byteStringToString(blockUUID)} is in multiple functions " )
310
342
}
311
343
uuidToBlock += (blockUUID -> block)
312
- if (blockUUID == entranceUUID) {
344
+ val isEntrance = blockUUID == entranceUUID
345
+ if (isEntrance) {
313
346
procedure.entryBlock = block
314
347
}
348
+ val checkPCStmt : Expr => Statement =
349
+ if isEntrance
350
+ then Assume (_, None , None , false )
351
+ else Assert (_, None , None )
352
+ block.address match {
353
+ case Some (addr) =>
354
+ val assertPC = checkPCStmt(BinaryExpr (BVEQ , Register (" _PC" , 64 ), BitVecLiteral (addr, 64 )))
355
+ block.statements.append(assertPC)
356
+ case _ => ()
357
+ }
315
358
block
316
359
}
317
360
@@ -781,7 +824,7 @@ class GTIRBToIR(
781
824
782
825
val newBlocks = ArrayBuffer (trueBlock, falseBlock)
783
826
procedure.addBlocks(newBlocks)
784
- block.statements.remove(tempIf)
827
+ // block.statements.remove(tempIf)
785
828
786
829
GoTo (newBlocks)
787
830
}
0 commit comments