Skip to content

Commit cc3c6c0

Browse files
committed
GenericDispatch and filter example
1 parent c409e88 commit cc3c6c0

File tree

4 files changed

+243
-31
lines changed

4 files changed

+243
-31
lines changed

src/kdrag/contrib/pcode/asmspec.py

Lines changed: 102 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,24 @@ def execute_insn(
465465
]
466466

467467

468+
def execute_spec_and_insn(
469+
tracestate0: TraceState,
470+
spec: AsmSpec,
471+
ctx: pcode.BinaryContext,
472+
verbose=True,
473+
) -> tuple[list[TraceState], list[VerificationCondition]]:
474+
"""
475+
Execute spec statements and then one instruction.
476+
"""
477+
addr = tracestate0.state.pc[0]
478+
specstmts = spec.addrmap.get(addr, [])
479+
tracestate, vcs = execute_spec_stmts(specstmts, tracestate0, ctx, verbose=verbose)
480+
if tracestate is None:
481+
return [], vcs
482+
new_tracestates = execute_insn(tracestate, ctx, verbose=verbose)
483+
return new_tracestates, vcs
484+
485+
468486
def init_trace_states(
469487
ctx: pcode.BinaryContext, mem: pcode.MemState, spec: AsmSpec, verbose=True
470488
) -> tuple[list[TraceState], list[VerificationCondition]]:
@@ -538,16 +556,95 @@ def run_all_paths(
538556
)
539557
)
540558
continue
541-
addr = tracestate.state.pc[0]
542-
specstmts = spec.addrmap.get(addr, [])
543-
tracestate, new_vcs = execute_spec_stmts(specstmts, tracestate, ctx)
559+
new_tracestates, new_vcs = execute_spec_and_insn(
560+
tracestate, spec, ctx, verbose=verbose
561+
)
544562
vcs.extend(new_vcs)
545-
if tracestate is not None:
546-
todo.extend(execute_insn(tracestate, ctx, verbose=verbose))
563+
todo.extend(new_tracestates)
564+
565+
# addr = tracestate.state.pc[0]
566+
# specstmts = spec.addrmap.get(addr, [])
567+
# tracestate, new_vcs = execute_spec_stmts(specstmts, tracestate, ctx)
568+
# vcs.extend(new_vcs)
569+
# if tracestate is not None:
570+
# todo.extend(execute_insn(tracestate, ctx, verbose=verbose))
547571

548572
return vcs
549573

550574

575+
class Debug:
576+
def __init__(self, ctx: pcode.BinaryContext, spec: AsmSpec):
577+
self.ctx = ctx
578+
self.spec: AsmSpec = spec
579+
self.tracestates: list[TraceState] = []
580+
self.vcs: list[VerificationCondition] = []
581+
self.breakpoints = set()
582+
583+
def spec_file(self, filename: str):
584+
self.spec = AsmSpec.of_file(filename, self.ctx)
585+
586+
def add_entry(self, name, precond=smt.BoolVal(True)):
587+
assert self.ctx.loader is not None, (
588+
"BinaryContext must be loaded before disassembling"
589+
)
590+
sym = self.ctx.loader.find_symbol(name)
591+
if sym is None:
592+
raise Exception(f"Symbol {name} not found in binary {self.ctx.filename}")
593+
self.spec.add_entry(name, sym.rebased_addr, precond)
594+
595+
def start(self, mem=None):
596+
if mem is None:
597+
mem = self.ctx.init_mem()
598+
tracestates, vcs = init_trace_states(self.ctx, mem, self.spec)
599+
self.tracestates = tracestates
600+
self.vcs = vcs
601+
602+
def breakpoint(self, addr):
603+
self.breakpoints.add(addr)
604+
605+
def step(self, n=1):
606+
for _ in range(n):
607+
tracestate = self.pop()
608+
new_tracestates, new_vcs = execute_spec_and_insn(
609+
tracestate, self.spec, self.ctx
610+
)
611+
self.vcs.extend(new_vcs)
612+
self.tracestates.extend(new_tracestates)
613+
614+
def run(self):
615+
while self.tracestates:
616+
if self.addr() in self.breakpoints:
617+
break
618+
self.step()
619+
620+
def pop(self):
621+
return self.tracestates.pop()
622+
623+
def addr(self):
624+
return self.tracestates[-1].state.pc[0]
625+
626+
def ghost(self, name):
627+
return self.tracestates[-1].ghost_env[name]
628+
629+
def reg(self, name):
630+
reg = self.ctx._subst_decls[name]
631+
return smt.simplify(
632+
self.ctx.substitute(self.tracestates[-1].state.memstate, reg)
633+
)
634+
635+
def ram(self, addr, size=None):
636+
if size is None:
637+
size = self.ctx.bits // 8
638+
return smt.simplify(
639+
self.tracestates[-1].state.memstate.getvalue_ram(addr, size)
640+
)
641+
642+
def insn(self):
643+
return self.ctx.disassemble(self.addr())
644+
645+
def model(self): ... # TODO.
646+
647+
551648
@dataclass
552649
class Results:
553650
successes: list[str] = dataclasses.field(default_factory=list)

src/kdrag/notation.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ class SortDispatch:
5959
Not(True)
6060
"""
6161

62-
def __init__(self, name=None, default=None, pointwise=False):
62+
def __init__(self, name=None, default=None, pointwise=False, default_factory=None):
6363
self.methods = {}
6464
self.default = default
6565
self.name = name
6666
self.pointwise = pointwise
67+
self.default_factory = default_factory
6768

6869
def register(self, sort, func):
6970
self.methods[sort] = func
@@ -96,30 +97,36 @@ def __call__(self, *args, **kwargs):
9697
"""
9798
if not args:
9899
raise TypeError("No arguments provided")
100+
x0 = args[0]
99101
sort = args[0].sort()
100-
res = self.methods.get(sort, self.default)
101-
if res is None:
102-
x0 = args[0]
103-
if self.pointwise and smt.is_func(x0):
104-
doms = smt.domains(x0)
105-
vs = [smt.FreshConst(d, prefix=f"x{n}") for n, d in enumerate(doms)]
106-
sort = x0.sort()
107-
# sorts are same or attempt coerce/lift
108-
return smt.Lambda(
109-
vs,
110-
self(
111-
*[
112-
arg(*vs)
113-
if isinstance(arg, smt.ExprRef) and arg.sort() == x0.sort()
114-
else arg
115-
for arg in args
116-
]
117-
),
118-
)
119-
else:
120-
raise NotImplementedError(
121-
f"No implementation of {self.name} for sort {sort}. Register a definition via {self.name}.register({sort}, your_code)",
122-
)
102+
res = self.methods.get(sort)
103+
if res is not None:
104+
return res(*args, **kwargs)
105+
elif self.default is not None:
106+
return self.default(*args, **kwargs)
107+
elif self.default_factory is not None:
108+
res = self.default_factory(sort)
109+
self.register(sort, res)
110+
return res(*args, **kwargs)
111+
elif self.pointwise and smt.is_func(x0):
112+
doms = smt.domains(x0)
113+
vs = [smt.FreshConst(d, prefix=f"x{n}") for n, d in enumerate(doms)]
114+
# sorts are same or attempt coerce/lift
115+
return smt.Lambda(
116+
vs,
117+
self(
118+
*[
119+
arg(*vs)
120+
if isinstance(arg, smt.ExprRef) and arg.sort() == x0.sort()
121+
else arg
122+
for arg in args
123+
]
124+
),
125+
)
126+
else:
127+
raise NotImplementedError(
128+
f"No implementation of {self.name} for sort {sort}. Register a definition via {self.name}.register({sort}, your_code)",
129+
)
123130
return res(*args, **kwargs)
124131

125132
def define(self, args, body):
@@ -132,6 +139,21 @@ def define(self, args, body):
132139
return defn
133140

134141

142+
def GenericDispatch(default_factory) -> SortDispatch:
143+
"""
144+
A decorator version of SortDispatch with a default factory.
145+
This is useful for definition Sort generic definitions.
146+
147+
>>> @GenericDispatch
148+
... def id(sort):
149+
... x = smt.Const("x", sort)
150+
... return kd.define("id", [x], x)
151+
>>> id(smt.IntVal(3))
152+
id(3)
153+
"""
154+
return SortDispatch(default_factory.__name__, default_factory=default_factory)
155+
156+
135157
call = SortDispatch(name="call")
136158
"""Sort based dispatch for `()` call syntax"""
137159
smt.ExprRef.__call__ = lambda x, *y, **kwargs: call(x, *y, **kwargs) # type: ignore
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import functools
2+
import kdrag.smt as smt
3+
import kdrag as kd
4+
import kdrag.theories.set as set_
5+
6+
7+
@functools.cache
8+
def Filter(T: smt.SortRef):
9+
"""
10+
A sort generic filter over sets of type T.
11+
12+
>>> Filter(smt.RealSort())
13+
Filter_Real
14+
"""
15+
dt = kd.Struct(f"Filter_{T}", ("sets", smt.SetSort(smt.SetSort(T))))
16+
A, B = smt.Consts("A B", smt.SetSort(T))
17+
F = smt.Const("F", dt)
18+
kd.notation.call.register(dt, lambda F, A: dt.sets(F)[A])
19+
kd.notation.wf.define(
20+
[F],
21+
smt.And(
22+
F(smt.FullSet(T)),
23+
smt.ForAll([A, B], F(A), A <= B, F(B)),
24+
smt.ForAll([A, B], F(A), F(B), F(A & B)),
25+
),
26+
)
27+
return dt
28+
29+
30+
class FilterMod:
31+
"""
32+
A module encapsulating filter theory over sets of type T.
33+
34+
>>> FilterMod(smt.RealSort()).filter_full
35+
|= ForAll(F, Implies(wf(F), sets(F)[K(Real, True)]))
36+
"""
37+
38+
def __init__(self, T: smt.SortRef):
39+
self.T = T
40+
self.S = Filter(T)
41+
F = smt.Const("F", self.S)
42+
set_.Set(T)
43+
A, B = smt.Consts("A B", smt.SetSort(T))
44+
self.filter_full = kd.prove(kd.QForAll([F], F(smt.FullSet(T))), unfold=1)
45+
self.filter_mono = kd.prove(
46+
kd.QForAll([F], smt.ForAll([A, B], F(A), A <= B, F(B))), unfold=1
47+
)
48+
self.filter_inter = kd.prove(
49+
kd.QForAll([F], smt.ForAll([A, B], F(A), F(B), F(A & B))), unfold=1
50+
)

tests/test_pcode.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
from kdrag.all import *
3-
from kdrag.contrib.pcode.asmspec import assemble_and_check_str, AsmSpec, run_all_paths, kd_macro
3+
from kdrag.contrib.pcode.asmspec import assemble_and_check_str, AsmSpec, run_all_paths, kd_macro, Debug
44
import kdrag.contrib.pcode as pcode
55

66
import pytest
@@ -175,6 +175,15 @@ def test_cli():
175175
res = subprocess.run("python3 -m kdrag.contrib.pcode /tmp/mov42.s", shell=True, capture_output=True, text=True)
176176
assert "verification conditions failed. ❌❌❌❌" in res.stdout
177177

178+
def riscv_as(code):
179+
with open("/tmp/mov42.s", "w") as f:
180+
f.write(code)
181+
f.flush()
182+
res = subprocess.run("""nix-shell -p pkgsCross.riscv32-embedded.buildPackages.gcc \
183+
--run "riscv32-none-elf-as /tmp/mov42.s -o /tmp/mov42.o" """, shell=True, check=True, capture_output=True, text=True)
184+
185+
186+
178187
@pytest.mark.slow
179188
def test_riscv32():
180189
code = """
@@ -221,4 +230,38 @@ def test_riscv32_write():
221230
vcs = run_all_paths(ctx, spec)
222231
assert len(vcs) == 1
223232
for vc in vcs:
224-
vc.verify(ctx)
233+
vc.verify(ctx)
234+
235+
@pytest.mark.slow
236+
def test_debugger():
237+
code = """
238+
.include "/tmp/knuckle.s"
239+
.option norvc
240+
.text
241+
.globl _start
242+
_start:
243+
li t0, 42 # load immediate 42 into t0
244+
li t1, 1000 # load address 1000 into t1
245+
sw t0, 0(t1) # store t0 at memory address
246+
ret
247+
"""
248+
riscv_as(code)
249+
ctx = pcode.BinaryContext("/tmp/mov42.o", langid="RISCV:LE:32:default")
250+
spec = AsmSpec.of_file("/tmp/mov42.s", ctx)
251+
d = Debug(ctx, spec)
252+
d.add_entry("_start")
253+
d.start()
254+
print(d.insn())
255+
print(d.addr())
256+
assert d.reg("t0").eq(smt.BitVecVal(42,32))
257+
assert not d.reg("t1").eq(smt.BitVecVal(1000,32))
258+
d.step()
259+
assert d.reg("t1").eq(smt.BitVecVal(1000,32))
260+
d.step()
261+
assert d.ram(smt.BitVecVal(1000,32), 4).eq(smt.BitVecVal(42,32))
262+
263+
264+
265+
266+
267+

0 commit comments

Comments
 (0)