@@ -434,23 +434,25 @@ def to_expr(self):
434434 else :
435435 return self .goal
436436
437- def ctx_find (self , n : int | smt .BoolRef ) -> smt .BoolRef :
437+ def ctx_find (self , n : int | smt .BoolRef ) -> tuple [ int , smt .BoolRef ] :
438438 """
439439 Find a hypothesis in the context by index or by matching expression.
440440
441441 >>> x = smt.Int("x")
442442 >>> g = Goal(sig=[], ctx=[x > 0, x < 10], goal=x == 5)
443443 >>> g.ctx_find(0)
444- x > 0
444+ (0, x > 0)
445445 >>> g.ctx_find(x < 10)
446- x < 10
446+ (1, x < 10)
447447 """
448448 if isinstance (n , int ):
449- return self .ctx [n ]
449+ if n < 0 :
450+ n += len (self .ctx )
451+ return n , self .ctx [n ]
450452 else :
451- for h in self .ctx :
453+ for i , h in enumerate ( self .ctx ) :
452454 if h .eq (n ):
453- return h
455+ return i , h
454456 raise KeyError (f"Hypothesis { n } not found in context" )
455457
456458 def proof (self ) -> "ProofState" :
@@ -782,9 +784,7 @@ def simp(self, at=None, unfold=False, path=None) -> "ProofState":
782784 self .goals [- 1 ] = goalctx ._replace (goal = newgoal )
783785 else :
784786 oldctx = goalctx .ctx
785- if at < 0 :
786- at = len (oldctx ) + at
787- old = oldctx [at ]
787+ (at , old ) = goalctx .ctx_find (at )
788788 new = kd .utils .pathmap (smt .simplify , old , path )
789789 if new .eq (old ):
790790 raise ValueError ("Simplify failed. Ctx is already simplified." )
@@ -1239,15 +1239,14 @@ def cb():
12391239 raise ValueError ("Unexpected case in goal for split tactic" , goal )
12401240 return self
12411241 else :
1242- if at < 0 :
1243- at = len (ctx ) + at
1244- if smt .is_or (ctx [at ]):
1242+ (at , hyp ) = goalctx .ctx_find (at )
1243+ if smt .is_or (hyp ):
12451244 self .pop_goal ()
1246- for c in ctx [ at ] .children ():
1245+ for c in hyp .children ():
12471246 self .goals .append (
12481247 goalctx ._replace (ctx = ctx [:at ] + [c ] + ctx [at + 1 :], goal = goal )
12491248 )
1250- elif smt .is_and (ctx [ at ] ):
1249+ elif smt .is_and (hyp ):
12511250 self .pop_goal ()
12521251 self .goals .append (
12531252 goalctx ._replace (
@@ -1323,7 +1322,7 @@ def exists(self, *ts) -> "ProofState":
13231322 return self
13241323
13251324 def rw (
1326- self , rule : kd .kernel .Proof | int , at = None , rev = False , ** kwargs
1325+ self , rule : kd .kernel .Proof | int | smt . BoolRef , at = None , rev = False , ** kwargs
13271326 ) -> "ProofState" :
13281327 """
13291328 `rewrite` allows you to apply rewrite rule (which may either be a Proof or an index into the context) to the goal or to the context.
@@ -1340,8 +1339,10 @@ def rw(
13401339 """
13411340 goalctx = self .top_goal ()
13421341 ctx , goal = goalctx .ctx , goalctx .goal
1343- if isinstance (rule , int ):
1344- rulethm = ctx [rule ]
1342+ if at is not None :
1343+ (at , hyp ) = goalctx .ctx_find (at )
1344+ if isinstance (rule , int ) or isinstance (rule , smt .ExprRef ):
1345+ _ , rulethm = goalctx .ctx_find (rule )
13451346 elif kd .kernel .is_proof (rule ):
13461347 rulethm = rule .thm
13471348 else :
@@ -1365,12 +1366,8 @@ def rw(
13651366 raise ValueError (f"Rewrite tactic failed. Not an equality { rulethm } " )
13661367 if at is None :
13671368 target = goal
1368- elif isinstance (at , int ):
1369- target = ctx [at ]
13701369 else :
1371- raise ValueError (
1372- "Rewrite tactic failed. `at` is not an index into the context"
1373- )
1370+ at , target = goalctx .ctx_find (at )
13741371 t_subst = kd .utils .pmatch_rec (vs , lhs , target )
13751372 if t_subst is None :
13761373 raise ValueError (
@@ -1383,13 +1380,11 @@ def rw(
13831380 target : smt .BoolRef = smt .substitute (target , (lhs1 , rhs1 ))
13841381 if isinstance (rulethm , smt .QuantifierRef ) and rulethm .is_forall ():
13851382 self .add_lemma (kd .kernel .specialize ([subst [v ] for v in vs ], rulethm ))
1386- if not isinstance (rule , int ) and kd .kernel .is_proof ( rule ):
1383+ if isinstance (rule , kd .kernel .Proof ):
13871384 self .add_lemma (rule )
13881385 if at is None :
13891386 self .goals .append (goalctx ._replace (ctx = ctx , goal = target ))
13901387 else :
1391- if at == - 1 :
1392- at = len (ctx ) - 1
13931388 self .goals .append (
13941389 goalctx ._replace (ctx = ctx [:at ] + [target ] + ctx [at + 1 :], goal = goal )
13951390 )
@@ -1474,13 +1469,13 @@ def beta(self, at=None):
14741469 self .goals [- 1 ] = goalctx ._replace (goal = newgoal )
14751470 else :
14761471 oldctx = goalctx .ctx
1477- old = oldctx [ at ]
1472+ at , old = goalctx . ctx_find ( at )
14781473 new = kd .rewrite .beta (old )
14791474 if new .eq (old ):
14801475 raise ValueError (
14811476 "Beta tactic failed. Ctx is already beta reduced." , old
14821477 )
1483- self .add_lemma (kd .kernel .prove (old == new ))
1478+ self .add_lemma (kd .kernel .prove (smt . Eq ( old , new ) ))
14841479 self .goals [- 1 ] = goalctx ._replace (
14851480 ctx = oldctx [:at ] + [new ] + oldctx [at + 1 :]
14861481 )
@@ -1518,7 +1513,7 @@ def unfold(self, *decls: smt.FuncDeclRef, at=None, keep=False) -> "ProofState":
15181513 else :
15191514 self .goals .append (goalctx ._replace (goal = e2 ))
15201515 else :
1521- e = goalctx .ctx [ at ]
1516+ at , e = goalctx .ctx_find ( at )
15221517 trace = []
15231518 e2 = kd .rewrite .unfold (e , decls = decls , trace = trace )
15241519 for lem in trace :
@@ -1639,12 +1634,21 @@ def contra(self):
16391634 )
16401635 return self .top_goal ()
16411636
1642- def clear (self , n : int ):
1637+ def clear (self , n : int | smt . BoolRef ):
16431638 """
16441639 Remove a hypothesis from the context
1640+
1641+ >>> p,q = smt.Bools("p q")
1642+ >>> l = Lemma(smt.Implies(p, q))
1643+ >>> h = l.intros()
1644+ >>> l
1645+ [p] ?|= q
1646+ >>> l.clear(h)
1647+ [] ?|= q
16451648 """
16461649 ctxgoal = self .pop_goal ()
16471650 ctx = ctxgoal .ctx .copy ()
1651+ n , _ = ctxgoal .ctx_find (n )
16481652 ctx .pop (n )
16491653 self .goals .append (ctxgoal ._replace (ctx = ctx ))
16501654 return self .top_goal ()
0 commit comments