@@ -844,6 +844,33 @@ def boolsimp(self) -> "ProofState":
844844 self .goals [- 1 ] = goalctx ._replace (ctx = newctx , goal = newgoal )
845845 return self
846846
847+ def forward (self , n : int | smt .BoolRef ) -> smt .BoolRef :
848+ """
849+ Remove the hypothesis of an implication in the context
850+
851+ >>> p,q,r = smt.Bools("p q r")
852+ >>> l = Lemma(smt.Implies(smt.And(p, smt.Implies(p, q)), r))
853+ >>> _ = l.intros()
854+ >>> l.split(at=0)
855+ [p, Implies(p, q)] ?|= r
856+ >>> l.forward(1)
857+ q
858+ >>> l
859+ [p, q] ?|= r
860+ """
861+ # TODO: extra by paramters? Give an exact index to use modus ponens?
862+ goalctx = self .top_goal ()
863+ (at , formula ) = goalctx .ctx_find (n )
864+ if smt .is_implies (formula ):
865+ hyp , conc = formula .children ()
866+ self .have (hyp , by = [])
867+ self .goals [- 1 ] = goalctx ._replace (
868+ ctx = goalctx .ctx [:at ] + [conc ] + goalctx .ctx [at + 1 :]
869+ )
870+ return conc
871+ else :
872+ raise ValueError ("forward failed. Not an implication" , formula )
873+
847874 def emt (self ):
848875 """
849876 Use egraph based equality modulo theories to simplify the goal.
@@ -1059,17 +1086,7 @@ def obtain(self, n: int | smt.QuantifierRef) -> smt.ExprRef | list[smt.ExprRef]:
10591086 """
10601087 goalctx = self .top_goal ()
10611088 ctx , goal = goalctx .ctx , goalctx .goal
1062- if isinstance (n , smt .QuantifierRef ):
1063- for i , f in enumerate (ctx ):
1064- if f .eq (n ):
1065- n = i
1066- break
1067- else :
1068- raise ValueError ("obtain failed. Formula not in context" , n )
1069- assert isinstance (n , int )
1070- if n < 0 :
1071- n = len (ctx ) + n
1072- formula = ctx [n ]
1089+ n , formula = goalctx .ctx_find (n )
10731090 if isinstance (formula , smt .QuantifierRef ) and formula .is_exists ():
10741091 self .pop_goal ()
10751092 fs , obtain_lemma = kd .kernel .obtain (formula )
@@ -1123,33 +1140,36 @@ def contract(self) -> "ProofState":
11231140 self .goals [- 1 ] = self .goalctx ._replace (ctx = newctx )
11241141 return self
11251142
1126- def specialize (self , n : int | smt .QuantifierRef , * ts ) :
1143+ def specialize (self , n : int | smt .QuantifierRef , * ts , keep = False ) -> smt . BoolRef :
11271144 """
11281145 Instantiate a universal quantifier in the context.
11291146
11301147 >>> x,y = smt.Ints("x y")
11311148 >>> l = Lemma(smt.Implies(smt.ForAll([x],x == y), True))
1132- >>> l.intros()
1149+ >>> hyp = l.intros()
1150+ >>> hyp
11331151 ForAll(x, x == y)
11341152 >>> l
11351153 [ForAll(x, x == y)] ?|= True
1136- >>> l.specialize(0, smt.IntVal(42))
1137- [ForAll(x, x == y), 42 == y] ?|= True
1154+ >>> l.specialize(hyp, smt.IntVal(42))
1155+ 42 == y
1156+ >>> l
1157+ [42 == y] ?|= True
11381158 """
11391159 goalctx = self .top_goal ()
1140- if isinstance (n , smt .QuantifierRef ):
1141- for i , f in enumerate (goalctx .ctx ):
1142- if f .eq (n ):
1143- n = i
1144- break
1145- else :
1146- raise ValueError ("Specialize failed. Formula not in context" , n )
1147- thm = goalctx .ctx [n ]
1160+ (n , thm ) = goalctx .ctx_find (n )
11481161 if isinstance (thm , smt .QuantifierRef ) and thm .is_forall ():
11491162 l = kd .kernel .specialize (ts , thm )
11501163 self .add_lemma (l )
1151- self .goals [- 1 ] = goalctx ._replace (ctx = goalctx .ctx + [l .thm .arg (1 )])
1152- return self
1164+ # kernel.specialize returns Implies(forall x, P, P[t/x])
1165+ newformula = l .thm .arg (1 )
1166+ if keep :
1167+ self .goals [- 1 ] = goalctx ._replace (ctx = goalctx .ctx + [newformula ])
1168+ else :
1169+ self .goals [- 1 ] = goalctx ._replace (
1170+ ctx = goalctx .ctx [:n ] + [newformula ] + goalctx .ctx [n + 1 :]
1171+ )
1172+ return newformula
11531173 else :
11541174 foralls = {
11551175 n : formula
@@ -1184,6 +1204,28 @@ def ext(self, at=None):
11841204 else :
11851205 raise ValueError ("Ext failed. Target is not an equality" , target )
11861206
1207+ def andE (self , n : int | smt .BoolRef ) -> list [smt .BoolRef ]:
1208+ """
1209+ Eliminate an `And` in the context.
1210+
1211+ >>> p,q = smt.Bools("p q")
1212+ >>> l = Lemma(smt.Implies(smt.And(p, q), p))
1213+ >>> _ = l.intros()
1214+ >>> p,q = l.andE(0)
1215+ >>> l
1216+ [p, q] ?|= p
1217+ """
1218+ goalctx = self .top_goal ()
1219+ (at , formula ) = goalctx .ctx_find (n )
1220+ if smt .is_and (formula ):
1221+ children = formula .children ()
1222+ self .goals [- 1 ] = goalctx ._replace (
1223+ ctx = goalctx .ctx [:at ] + children + goalctx .ctx [at + 1 :]
1224+ )
1225+ return children
1226+ else :
1227+ raise ValueError ("andE failed. Not an and" , formula )
1228+
11871229 def split (self , at = None ) -> "ProofState" :
11881230 """
11891231 `split` breaks apart an `And` or bi-implication `==` goal.
@@ -1243,21 +1285,22 @@ def cb():
12431285 else :
12441286 (at , hyp ) = goalctx .ctx_find (at )
12451287 if smt .is_or (hyp ):
1288+ # Make N new goals for each disjunct
12461289 self .pop_goal ()
12471290 for c in hyp .children ():
12481291 self .goals .append (
12491292 goalctx ._replace (ctx = ctx [:at ] + [c ] + ctx [at + 1 :], goal = goal )
12501293 )
1251- elif smt .is_and (hyp ):
1294+ elif smt .is_and (hyp ): # TODO: phase this out in favor of andE.
12521295 self .pop_goal ()
12531296 self .goals .append (
12541297 goalctx ._replace (
12551298 ctx = ctx [:at ] + ctx [at ].children () + ctx [at + 1 :], goal = goal
12561299 )
12571300 )
1301+ return self
12581302 else :
12591303 raise ValueError ("Split failed on" , ctx [at ], "in context" , ctx )
1260- return self
12611304
12621305 def left (self , n = 0 ):
12631306 """
0 commit comments