@@ -579,7 +579,7 @@ def prune(
579579
580580
581581def bysect (
582- thm , by0 : list [kd .kernel .Proof ] | dict [object , kd .kernel .Proof ], ** kwargs
582+ thm , by : list [kd .kernel .Proof ] | dict [object , kd .kernel .Proof ], ** kwargs
583583) -> Sequence [tuple [object , kd .kernel .Proof ]]:
584584 """
585585 Bisect the `by` list to find a minimal set of premises that prove `thm`. Presents the same interface as `prove`
@@ -589,29 +589,29 @@ def bysect(
589589 >>> bysect(x == z, by=by)
590590 [(1, |= x == y), (3, |= y == z)]
591591 """
592- if isinstance (by0 , list ):
593- by = list (enumerate (by0 ))
594- elif isinstance (by0 , dict ):
595- by = list (by0 .items ())
592+ if isinstance (by , list ):
593+ by1 = list (enumerate (by ))
594+ elif isinstance (by , dict ):
595+ by1 = list (by .items ())
596596 else :
597597 raise ValueError ("by must be a list or dict" )
598598 n = 2
599- while len (by ) >= 2 :
600- subset_size = len (by ) // n
601- for i in range (0 , len (by ), subset_size ):
602- rest = by [:i ] + by [i + subset_size :]
599+ while len (by1 ) >= 2 :
600+ subset_size = len (by1 ) // n
601+ for i in range (0 , len (by1 ), subset_size ):
602+ rest = by1 [:i ] + by1 [i + subset_size :]
603603 try :
604604 kd .prove (thm , by = [b for _ , b in rest ], ** kwargs )
605- by = rest
605+ by1 = rest
606606 n = max (n - 1 , 2 )
607607 break
608608 except Exception as _ :
609609 pass
610610 else :
611- if n == len (by ):
611+ if n == len (by1 ):
612612 break
613- n = min (len (by ), n * 2 )
614- return by
613+ n = min (len (by1 ), n * 2 )
614+ return by1
615615
616616
617617def subterms (t : smt .ExprRef , into_binder = False ):
@@ -780,22 +780,28 @@ def ast_size_sexpr(t: smt.AstRef) -> int:
780780@dataclass (frozen = True )
781781class QuantifierHole :
782782 vs : list [smt .ExprRef ]
783- # orig_vs : list[smt.ExprRef] to be able to exactly reconstruct original term?
783+ orig_vs : list [smt .ExprRef ] # to be able to exactly reconstruct original term?
784+
785+ def has_right (self ) -> bool :
786+ return False
784787
785788
786789class LambdaHole (QuantifierHole ):
787790 def wrap (self , body : smt .ExprRef ) -> smt .ExprRef :
788- return smt .Lambda (self .vs , body )
791+ body = smt .substitute (body , * zip (self .vs , self .orig_vs ))
792+ return smt .Lambda (self .orig_vs , body )
789793
790794
791795class ForAllHole (QuantifierHole ):
792796 def wrap (self , body : smt .ExprRef ) -> smt .ExprRef :
793- return smt .ForAll (self .vs , body )
797+ body = smt .substitute (body , * zip (self .vs , self .orig_vs ))
798+ return smt .ForAll (self .orig_vs , body )
794799
795800
796801class ExistsHole (QuantifierHole ):
797802 def wrap (self , body : smt .ExprRef ) -> smt .ExprRef :
798- return smt .Exists (self .vs , body )
803+ body = smt .substitute (body , * zip (self .vs , self .orig_vs ))
804+ return smt .Exists (self .orig_vs , body )
799805
800806
801807@dataclass (frozen = True )
@@ -835,23 +841,28 @@ class Zipper:
835841 >>> t = smt.Lambda([x,y], (x + y) * (y + z))
836842 >>> z1 = Zipper.from_term(t)
837843 >>> z1.open_binder().arg(1).left().arg(0)
838- Zipper(ctx=[LambdaHole(vs=[X!..., Y!...]), DeclHole(f=*, _left=(), _right=(Y!... + z,)), DeclHole(f=+, _left=(), _right=(Y!...,))], t=X!...)
839- >>> z1.pop ().pop ().pop ()
840- Zipper(ctx=[], t=Lambda([X!..., Y!... ], (X!... + Y!... )*(Y!... + z)))
844+ Zipper(ctx=[LambdaHole(vs=[X!..., Y!...], orig_vs=[x, y] ), DeclHole(f=*, _left=(), _right=(Y!... + z,)), DeclHole(f=+, _left=(), _right=(Y!...,))], t=X!...)
845+ >>> z1.up ().up ().up ()
846+ Zipper(ctx=[], t=Lambda([x, y ], (x + y )*(y + z)))
841847 """
842848
843- ctx : list [Hole ] # trail / stack
849+ ctx : list [Hole ] # trail / stack,. Consider saving old term
844850 t : smt .ExprRef
845851
846852 @classmethod
847853 def from_term (cls , t : smt .ExprRef ) -> "Zipper" :
848854 return cls ([], t )
849855
850- def pop (self ) -> "Zipper" : # up?
856+ def up (self ) -> "Zipper" : # up?
851857 hole = self .ctx .pop ()
852858 self .t = hole .wrap (self .t )
853859 return self
854860
861+ def rebuild (self ) -> smt .ExprRef :
862+ while self .ctx :
863+ self .up ()
864+ return self .t
865+
855866 def copy (self ) -> "Zipper" :
856867 return Zipper (self .ctx .copy (), self .t )
857868
@@ -879,19 +890,75 @@ def arg(self, n: int) -> "Zipper":
879890
880891 def open_binder (self ) -> "Zipper" :
881892 assert isinstance (self .t , smt .QuantifierRef )
893+ t = self .t
894+ orig_vs , _body = kd .utils .open_binder_unhygienic (
895+ t
896+ ) # TODO: don't need to build body
882897 vs , body = kd .utils .open_binder (self .t )
883898 if self .t .is_forall ():
884- hole = ForAllHole (vs )
899+ hole = ForAllHole (vs , orig_vs )
885900 elif self .t .is_exists ():
886- hole = ExistsHole (vs )
901+ hole = ExistsHole (vs , orig_vs )
887902 elif self .t .is_lambda ():
888- hole = LambdaHole (vs )
903+ hole = LambdaHole (vs , orig_vs )
889904 else :
890905 raise NotImplementedError ("Unknown quantifier type" , self .t )
891906 self .ctx .append (hole )
892907 self .t = body
893908 return self
894909
910+ def __iter__ (self ):
911+ return self
912+
913+ def __next__ (self ) -> smt .ExprRef :
914+ """
915+ All subterms of the term in a pre-order traversal.
916+
917+ >>> x,y,z = smt.Ints("x y z")
918+ >>> list(Zipper([], x + y*z))
919+ [x, y*z, y, z]
920+ """
921+ if isinstance (self .t , smt .QuantifierRef ):
922+ self .open_binder ()
923+ return self .t
924+ elif smt .is_const (self .t ):
925+ while len (self .ctx ) != 0 and not self .ctx [- 1 ].has_right ():
926+ self .up ()
927+ if len (self .ctx ) == 0 :
928+ raise StopIteration
929+ else :
930+ self .right ()
931+ return self .t
932+ elif smt .is_app (self .t ):
933+ self .arg (0 )
934+ return self .t
935+ else :
936+ raise ValueError ("Unexpected term in Zipper iteration" , self .t )
937+
938+ def pmatch (
939+ self , vs : list [smt .ExprRef ], pat : smt .ExprRef
940+ ) -> Optional [dict [smt .ExprRef , smt .ExprRef ]]:
941+ """
942+ Pattern match the current term against a pattern with variables vs.
943+ Leaves the zipper in a context state.
944+ This can be used to replace but rebuild using original context
945+
946+ >>> x,y,z,a,b,c = smt.Ints("x y z a b c")
947+ >>> zip = Zipper([], x + smt.Lambda([y], y*z)[x])
948+ >>> (subst := zip.pmatch([a,b], a*b))
949+ {b: z, a: Y!...}
950+ >>> zip.t = smt.IntVal(1) * subst[a]
951+ >>> zip.rebuild()
952+ x + Lambda(y, 1*y)[x]
953+ """
954+ subst = pmatch (vs , pat , self .t )
955+ if subst is not None :
956+ return subst
957+ for t in self :
958+ subst = pmatch (vs , pat , t )
959+ if subst is not None :
960+ return subst
961+
895962 def __hash__ (self ):
896963 """
897964 Warning: If you are hashing Zippers, make sure you are copying them.
0 commit comments