Skip to content

Commit 806d935

Browse files
committed
do dag traversal instead of tree
1 parent 5e11508 commit 806d935

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

src/kdrag/utils.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -804,25 +804,43 @@ def bysect(
804804
return by1
805805

806806

807-
def subterms(t: smt.ExprRef, into_binder=False):
807+
def subterms(t: smt.ExprRef, into_binder=False) -> list[smt.ExprRef]:
808808
"""Generate all subterms of a term
809809
810810
>>> x,y = smt.Ints("x y")
811811
>>> list(subterms(x + y == y))
812-
[x + y == y, y, x + y, y, x]
812+
[x + y == y, x + y, y, x]
813813
>>> list(subterms(smt.ForAll([x], x + y == y)))
814814
[ForAll(x, x + y == y)]
815815
>>> list(subterms(smt.ForAll([x], x + y == y), into_binder=True))
816-
[ForAll(x, x + y == y), Var(0) + y == y, y, Var(0) + y, y, Var(0)]
816+
[ForAll(x, x + y == y), Var(0) + y == y, Var(0) + y, y, Var(0)]
817817
"""
818818
todo = [t]
819+
seen = {t.get_id()}
820+
res = [t]
819821
while len(todo) > 0:
820822
x = todo.pop()
821-
yield x
822823
if smt.is_app(x):
823-
todo.extend(x.children())
824-
elif isinstance(x, smt.QuantifierRef) and into_binder:
825-
todo.append(x.body())
824+
children = x.children()
825+
new = [c for c in children if c.get_id() not in seen]
826+
todo.extend(new)
827+
res.extend(new)
828+
seen.update(c.get_id() for c in new)
829+
elif isinstance(x, smt.QuantifierRef):
830+
if into_binder:
831+
body = x.body()
832+
idx = body.get_id()
833+
if idx not in seen:
834+
seen.add(idx)
835+
todo.append(body)
836+
res.append(body)
837+
else:
838+
continue
839+
elif smt.is_var(x):
840+
continue
841+
else:
842+
raise Exception("Unexpected term in subterms", x)
843+
return res
826844

827845

828846
def sorts(t: smt.ExprRef):
@@ -833,29 +851,15 @@ def sorts(t: smt.ExprRef):
833851
yield t.sort()
834852

835853

836-
def consts(t: smt.ExprRef) -> set[smt.ExprRef]:
854+
def consts(t: smt.ExprRef) -> list[smt.ExprRef]:
837855
"""
838856
Return all constants in a term.
839857
840858
>>> x,y = smt.Ints("x y")
841-
>>> consts(x + (y * y) + 1) == {smt.IntVal(1), x, y}
859+
>>> set(consts(x + (y * y) + 1)) == {smt.IntVal(1), x, y}
842860
True
843861
"""
844-
ts = set()
845-
todo = [t]
846-
while todo:
847-
t = todo.pop()
848-
if smt.is_const(t):
849-
ts.add(t)
850-
elif smt.is_app(t):
851-
todo.extend(t.children())
852-
elif isinstance(t, smt.QuantifierRef):
853-
todo.append(t.body())
854-
elif smt.is_var(t):
855-
continue
856-
else:
857-
raise Exception("Unexpected term in consts", t)
858-
return ts
862+
return [e for e in subterms(t, into_binder=True) if smt.is_const(e)]
859863

860864

861865
def decls(t: smt.ExprRef) -> set[smt.FuncDeclRef]:

0 commit comments

Comments
 (0)