Skip to content

Commit 6fae002

Browse files
committed
pcode interpreter uses folded expressions. Moved tptp and smtlib functionality out of solvers and into printers
1 parent 8da0d3f commit 6fae002

File tree

7 files changed

+195
-94
lines changed

7 files changed

+195
-94
lines changed

src/kdrag/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@
2929

3030
define = kernel.define
3131

32+
33+
def define_const(name: str, body: smt.ExprRef) -> smt.ExprRef:
34+
"""
35+
Define a constant.
36+
37+
>>> x = define_const("define_const_example", smt.IntVal(42))
38+
>>> x
39+
define_const_example
40+
>>> rewrite.unfold(x)
41+
42
42+
"""
43+
# TODO: Remove this type ignore and rename all uses of define to define_const where no constants expected
44+
# arguably define is define_fun
45+
return kernel.define(name, [], body) # type: ignore
46+
47+
3248
FreshVar = kernel.FreshVar
3349

3450
FreshVars = tactics.FreshVars

src/kdrag/contrib/pcode/__init__.py

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,16 @@ def getvalue(self, vnode: pypcode.Varnode) -> smt.BitVecRef | int:
129129
if vnode.space.name == "ram":
130130
mem = self.ram
131131
elif vnode.space.name == "register":
132-
mem = self.register
132+
return bv.select_concat(
133+
self.register,
134+
smt.BitVec("&" + vnode.getRegisterName(), self.bits),
135+
vnode.size,
136+
)
133137
elif vnode.space.name == "unique":
134138
mem = self.unique
135139
else:
136140
raise ValueError(f"Unknown memory space: {vnode.space.name}")
137-
return bv.SelectConcat(
141+
return bv.select_concat(
138142
mem, smt.BitVecVal(vnode.offset, self.bits), vnode.size
139143
)
140144

@@ -144,42 +148,44 @@ def setvalue(self, vnode: pypcode.Varnode, value: smt.BitVecRef):
144148
if space == "ram":
145149
return self.setvalue_ram(offset, value)
146150
elif space == "register":
147-
return self.set_register(vnode.offset, value)
151+
return self._replace(
152+
register=bv.store_concat(
153+
self.register,
154+
smt.BitVec("&" + vnode.getRegisterName(), self.bits),
155+
value,
156+
)
157+
)
148158
elif space == "unique":
149-
return self._replace(unique=bv.StoreConcat(self.unique, offset, value))
159+
return self._replace(unique=bv.store_concat(self.unique, offset, value))
150160
else:
151161
raise ValueError(f"Unknown memory space: {space}")
152162

153-
def set_register(self, offset: smt.BitVecRef | int, value: smt.BitVecRef):
154-
# This is mainly for the purpose of manually setting PC in evaluator loop
155-
if not isinstance(offset, smt.BitVecRef):
156-
offset1 = smt.BitVecVal(offset, self.bits)
157-
else:
158-
offset1 = offset
159-
return self._replace(
160-
register=bv.StoreConcat(
161-
self.register,
162-
smt.BitVecVal(offset1, self.bits),
163-
value,
164-
)
165-
)
166-
167163
def getvalue_ram(self, offset: smt.BitVecRef | int, size: int) -> smt.BitVecRef:
168164
# TODO: update read?
169-
return bv.SelectConcat(self.ram, offset, size)
165+
return bv.select_concat(self.ram, offset, size)
170166

171167
def setvalue_ram(self, offset: smt.BitVecRef | int, value: smt.BitVecRef):
172168
if not isinstance(offset, smt.BitVecRef):
173169
offset1 = smt.BitVecVal(offset, self.bits)
174170
else:
175171
offset1 = offset
176172
return self._replace(
177-
ram=bv.StoreConcat(self.ram, offset1, value),
173+
ram=bv.store_concat(self.ram, offset1, value),
178174
write=self.write + [(offset1, value.size())], # fun.MultiStore(
179175
# self.write, offset1, *([smt.BoolVal(True)] * value.size())
180176
# ),
181177
)
182178

179+
def __str__(self):
180+
# use sexpr form which uses `let` for shared expressions.
181+
# Using sexpr on the And expression returns both the register file and ram with shared expressions lifted
182+
cur_ram = smt.Const("CUR_RAM", self.ram.sort())
183+
cur_register = smt.Const("CUR_REGFILE", self.register.sort())
184+
return f"MemState({smt.And(cur_ram == self.ram, cur_register == self.register).sexpr()})"
185+
186+
def __repr__(self):
187+
return self.__str__()
188+
183189

184190
# Pure Operations
185191

@@ -226,9 +232,8 @@ def executeSubpiece(op: pypcode.PcodeOp, memstate: MemState) -> MemState:
226232
def executePopcount(op: pypcode.PcodeOp, memstate: MemState) -> MemState:
227233
assert op.output is not None
228234
in1 = memstate.getvalue(op.inputs[0])
229-
out = smt.BitVecVal(0, op.inputs[0].size * 8)
230-
for i in range(op.inputs[0].size * 8):
231-
out += (in1 >> i) & 1
235+
assert isinstance(in1, smt.BitVecRef)
236+
out = bv.popcount(in1)
232237
outsize = op.output.size * 8
233238
insize = op.inputs[0].size * 8
234239
if outsize > insize:
@@ -307,21 +312,40 @@ def __init__(self, filename=None, langid="x86:LE:64:default"):
307312
self.filename = None
308313
self.loader = None
309314
self.bin_hash = hash((filename, langid))
315+
self.ctx = pypcode.Context(langid) # TODO: derive from cle
310316
ainfo = archinfo.ArchPcode(langid)
311-
self.pc: tuple[int, int] = ainfo.registers[
317+
pc: tuple[int, int] = ainfo.registers[
312318
"pc"
313319
] # TODO: handle different archs? Or will "pc" always work?
320+
for name, vnode in self.ctx.registers.items():
321+
if vnode.offset == pc[0] and vnode.size == pc[1]:
322+
self.pc = vnode
323+
break
324+
else:
325+
raise ValueError("Could not find PC register", pc)
326+
314327
self.bits = ainfo.bits
315-
assert self.bits == self.pc[1] * 8
328+
assert self.bits == self.pc.size * 8
316329
self.memory_endness = ainfo.memory_endness # TODO
317330
self.register_endness = ainfo.register_endness # TODO
318-
self.ctx = pypcode.Context(langid) # TODO: derive from cle
319331

320332
# Definitions that are used but may need to be unfolded
321-
self.definitions: list[smt.FuncDeclRef] = list(bv.select64_le.values())
322-
self.definitions.extend(bv.select64_be.values())
323-
self.definitions.extend(bv.select32_le.values())
324-
self.definitions.extend(bv.select32_be.values())
333+
# &reg is also added in load
334+
self.definitions = [
335+
bv.select_concats(bits, size, le=le)
336+
for le in [True, False]
337+
for bits in [32, 64]
338+
for size in [16, 32, 64]
339+
]
340+
self.definitions.extend([bv.popcounts(size) for size in [8, 16, 32, 64]])
341+
self.definitions.extend(
342+
[
343+
bv.store_concats(bits, size, le=le)
344+
for le in [True, False]
345+
for bits in [32, 64]
346+
for size in [16, 32, 64]
347+
]
348+
)
325349
if filename is not None:
326350
self.load(filename)
327351

@@ -339,6 +363,11 @@ def load(self, main_binary, **kwargs):
339363
name: smt.BitVec(name, vnode.size * 8)
340364
for name, vnode in self.ctx.registers.items()
341365
}
366+
# Make offsets available as definitions. &regname is offset in regfile
367+
self.definitions.extend(
368+
kd.define_const("&" + name, smt.BitVecVal(vnode.offset, self.bits)).decl()
369+
for name, vnode in self.ctx.registers.items()
370+
)
342371
# support %reg names
343372
decls.update(
344373
{
@@ -562,8 +591,8 @@ def sym_execute(
562591
if pcode_pc == 0:
563592
max_insns1 = max_insns - 1
564593
# pcode does not have explicit PC updates, but we want them
565-
memstate2 = memstate1.set_register(
566-
self.pc[0], smt.BitVecVal(addr, self.pc[1] * 8)
594+
memstate2 = memstate1.setvalue(
595+
self.pc, smt.BitVecVal(addr, self.pc.size * 8)
567596
)
568597
else:
569598
max_insns1 = max_insns
@@ -582,8 +611,8 @@ def sym_execute(
582611
): # pcode_pc == 0 means we are at the start of an instruction. Kind of. There are some edge cases, TODO
583612
max_insns -= 1
584613
# pcode does not have explicit PC updates, but we want them
585-
memstate1 = memstate1.set_register(
586-
self.pc[0], smt.BitVecVal(pc1[0], self.pc[1] * 8)
614+
memstate1 = memstate1.setvalue(
615+
self.pc, smt.BitVecVal(pc1[0], self.pc.size * 8)
587616
)
588617
if pc1[0] in breakpoints:
589618
res.append(SimState(memstate1, pc1, path_cond))
@@ -600,7 +629,7 @@ def get_reg(self, memstate: MemState, regname: str) -> smt.BitVecRef:
600629
>>> ctx = BinaryContext()
601630
>>> memstate = MemState.Const("test_mem")
602631
>>> memstate = ctx.set_reg(memstate, "RAX", smt.BitVec("RAX", 64))
603-
>>> ctx.get_reg(memstate, "RAX")
632+
>>> ctx.simplify(ctx.get_reg(memstate, "RAX"))
604633
RAX
605634
"""
606635
vnode = self.ctx.registers[regname]
@@ -622,9 +651,12 @@ def init_mem(self) -> MemState:
622651
>>> ctx = BinaryContext()
623652
>>> memstate = ctx.init_mem()
624653
>>> ctx.get_reg(memstate, "RAX")
625-
RAX!...
654+
select64le(register(state0), &RAX)
655+
"""
656+
memstate = MemState.Const("state0", bits=self.bits)
657+
return memstate
658+
# Old code to initialize memory with dummy regnames. Maybe still useful?
626659
"""
627-
memstate = MemState.Const("mem0", bits=self.bits)
628660
free_offset = 0
629661
for name, vnode in self.ctx.registers.items():
630662
# interestingness heuristic on length of name
@@ -637,6 +669,7 @@ def init_mem(self) -> MemState:
637669
)
638670
free_offset = vnode.offset + vnode.size
639671
return memstate
672+
"""
640673

641674
def get_regs(self, memstate: MemState) -> dict[str, smt.BitVecRef]:
642675
"""
@@ -698,11 +731,22 @@ def unfold(self, expr: smt.ExprRef) -> smt.ExprRef:
698731
x
699732
>>> import kdrag.theories.bitvec as bv
700733
>>> ram = smt.Array("ram", BV[64], BV[8])
701-
>>> smt.simplify(ctx.unfold(bv.select64_le[16](ram, x)))
734+
>>> smt.simplify(ctx.unfold(bv.select_concat(ram, x, 2)))
702735
Concat(ram[1 + x], ram[x])
703736
"""
704737
return kd.kernel.unfold(expr, self.definitions)[0]
705738

739+
def simplify(self, expr: smt.ExprRef) -> smt.ExprRef:
740+
"""
741+
Call simplify and unfold if unfolding makes expression smaller.
742+
"""
743+
e1 = smt.simplify(expr)
744+
e2 = smt.simplify(self.unfold(expr))
745+
if len(e2.sexpr()) < len(e1.sexpr()):
746+
return e2
747+
else:
748+
return e1
749+
706750
def model_registers(
707751
self,
708752
model: smt.ModelRef,
@@ -724,3 +768,16 @@ def test_pcode():
724768
tx = ctx.translate(b"\xf7\xd8") # neg %eax
725769
for op in tx.ops:
726770
pass
771+
772+
773+
class StateExpr(NamedTuple):
774+
ctx: BinaryContext
775+
expr: smt.ExprRef
776+
777+
def to_lambda(self) -> smt.QuantifierRef:
778+
mem = smt.Const("mem", MemStateSort[self.ctx.bits])
779+
memstate = MemState.Const("mem", bits=self.ctx.bits)
780+
return smt.Lambda([mem], self(memstate))
781+
782+
def __call__(self, memstate: MemState) -> smt.ExprRef:
783+
return self.ctx.substitute(memstate, self.expr)

src/kdrag/contrib/pcode/asmspec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,8 @@ def init_trace_states(
513513
for n, stmt in enumerate(specstmts):
514514
if isinstance(stmt, Cut) or isinstance(stmt, Entry):
515515
init_trace_id += 1
516-
mem1 = mem.set_register(
517-
ctx.pc[0], smt.BitVecVal(addr, ctx.pc[1] * 8)
516+
mem1 = mem.setvalue(
517+
ctx.pc, smt.BitVecVal(addr, ctx.pc.size * 8)
518518
) # set pc to start addr before entering code
519519
precond = ctx.substitute(mem1, stmt.expr) # No ghost? Use substitute?
520520
ghost_env = {
@@ -715,7 +715,7 @@ def ghost(self, name):
715715

716716
def reg(self, name):
717717
reg = self.ctx.state_vars[name]
718-
return smt.simplify(
718+
return self.ctx.simplify(
719719
self.ctx.substitute(self.tracestates[-1].state.memstate, reg)
720720
)
721721

@@ -725,7 +725,7 @@ def ram(self, addr, size=None):
725725
"""
726726
if size is None:
727727
size = self.ctx.bits // 8
728-
return smt.simplify(
728+
return self.ctx.simplify(
729729
self.tracestates[-1].state.memstate.getvalue_ram(addr, size)
730730
)
731731

src/kdrag/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def define(
268268
print("WARNING: Redefining function", f, "from", defns[f].ax, "to", def_ax.thm)
269269
defns[f] = defn
270270
if len(args) == 0:
271-
return f() # Convenience
271+
return f() # Convenience. TODO: Remove
272272
else:
273273
return f
274274

src/kdrag/printers/tptp.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def mangle_decl(d: smt.FuncDeclRef, env=[]):
66
"""Mangle a declaration to a tptp name. SMTLib supports type based overloading, TPTP does not."""
77
# single quoted (for operators) + underscore + hex id
88
id_, name = d.get_id(), d.name()
9-
name = name.replace("!", "bang")
9+
name = name.replace("!", "__")
1010
# TODO: mangling of operators is busted
1111
# name = name.replace("*", "star")
1212
assert id_ >= 0x80000000
@@ -70,7 +70,12 @@ def neg(e: smt.BoolRef) -> str:
7070
def expr_to_tptp(
7171
expr: smt.ExprRef, env=None, format="thf", theories=True, no_mangle=None
7272
) -> str:
73-
"""Pretty print expr as TPTP"""
73+
"""Pretty print expr as TPTP
74+
75+
>>> x,y = smt.Ints("x y")
76+
>>> expr_to_tptp(smt.ForAll([x,y], smt.Implies(x < y, x + 1 <= y)))
77+
'(![X...:$int, Y...:$int] : ($less(X...,Y...) => $lesseq($sum(X...,1),Y...))'
78+
"""
7479
if env is None:
7580
env = []
7681
if no_mangle is None:
@@ -190,7 +195,13 @@ def expr_to_tptp(
190195

191196

192197
def sort_to_tptp(sort: smt.SortRef):
193-
"""Pretty print sort as tptp"""
198+
"""Pretty print sort as tptp
199+
200+
>>> sort_to_tptp(smt.IntSort())
201+
'$int'
202+
>>> sort_to_tptp(smt.ArraySort(smt.IntSort(), smt.BoolSort()))
203+
'($int > $o)'
204+
"""
194205
name = sort.name()
195206
if name == "Int":
196207
return "$int"

0 commit comments

Comments
 (0)