44from enum import IntEnum
55import operator as op
66from . import config
7+ from typing import NamedTuple
78
89
910class Calc :
@@ -197,58 +198,191 @@ def simp(t: smt.ExprRef, by: list[kd.kernel.Proof] = [], **kwargs) -> kd.kernel.
197198 return lemma (t == t1 , by = by , ** kwargs )
198199
199200
201+ class Sequent (NamedTuple ):
202+ ctx : list [smt .BoolRef ]
203+ goal : smt .BoolRef
204+
205+ def __repr__ (self ):
206+ return repr (self .ctx ) + " ?|- " + repr (self .goal )
207+
208+
200209class Lemma :
201- # Isar style forward proof
202210 def __init__ (self , goal : smt .BoolRef ):
203- # self.cur_goal = goal
204211 self .lemmas = []
205212 self .thm = goal
206- self .goals = [([], goal )]
213+ self .goals = [Sequent ([], goal )]
214+
215+ def fixes (self ):
216+ ctx , goal = self .goals [- 1 ]
217+ if smt .is_quantifier (goal ) and goal .is_forall ():
218+ self .goals .pop ()
219+ vs , herb_lemma = kd .kernel .herb (goal )
220+ self .lemmas .append (herb_lemma )
221+ self .goals .append (Sequent (ctx , herb_lemma .thm .arg (0 )))
222+ return vs
223+ else :
224+ raise ValueError (f"fixes tactic failed. Not a forall { goal } " )
207225
208226 def intros (self ):
209227 ctx , goal = self .goals .pop ()
210228 if smt .is_quantifier (goal ) and goal .is_forall ():
211229 vs , herb_lemma = kd .kernel .herb (goal )
212230 self .lemmas .append (herb_lemma )
213- self .goals .append ((ctx , herb_lemma .thm .arg (0 )))
214- return vs
231+ self .goals .append (Sequent (ctx , herb_lemma .thm .arg (0 )))
232+ if len (vs ) == 1 :
233+ return vs [0 ]
234+ else :
235+ return vs
215236 elif smt .is_implies (goal ):
216- self .goals .append ((ctx + [goal .arg (0 )], goal .arg (1 )))
217- return self
237+ self .goals .append (Sequent (ctx + [goal .arg (0 )], goal .arg (1 )))
238+ return self .top_goal ()
239+ elif smt .is_not (goal ):
240+ self .goals .append ((ctx + [goal .arg (0 )], smt .BoolVal (False )))
241+ return
242+ else :
243+ raise ValueError ("Intros failed." )
218244
219245 def cases (self , t ):
220246 ctx , goal = self .goals .pop ()
221247 if t .sort () == smt .BoolSort ():
222- self .goals .append ((ctx + [smt .Not (t )], goal ))
223- self .goals .append ((ctx + [t ], goal ))
248+ self .goals .append (Sequent (ctx + [smt .Not (t )], goal ))
249+ self .goals .append (Sequent (ctx + [t ], goal ))
224250 elif isinstance (t , smt .DatatypeRef ):
225251 dsort = t .sort ()
226252 for i in reversed (range (dsort .num_constructors ())):
227- self .goals .append ((ctx + [dsort .recognizer (i )(t )], goal ))
253+ self .goals .append (Sequent (ctx + [dsort .recognizer (i )(t )], goal ))
228254 else :
229255 raise ValueError ("Cases failed. Not a bool or datatype" )
230- return self
256+ return self . top_goal ()
231257
232258 def auto (self ):
233259 ctx , goal = self .goals [- 1 ]
234260 self .lemmas .append (lemma (smt .Implies (smt .And (ctx ), goal )))
235261 self .goals .pop ()
236- return self
262+ return self . top_goal ()
237263
238- def split (self ):
264+ def einstan (self , n ):
239265 ctx , goal = self .goals [- 1 ]
240- if smt .is_and (goal ):
266+ formula = ctx [n ]
267+ if smt .is_quantifier (formula ) and formula .is_exists ():
241268 self .goals .pop ()
242- self .goals .extend ([(ctx , c ) for c in goal .children ()])
269+ fs , einstan_lemma = kd .kernel .einstan (formula )
270+ self .lemmas .append (einstan_lemma )
271+ self .goals .append (
272+ Sequent (ctx [:n ] + [einstan_lemma .thm .arg (1 )] + ctx [n + 1 :], goal )
273+ )
274+ if len (fs ) == 1 :
275+ return fs [0 ]
276+ else :
277+ return fs
278+ else :
279+ raise ValueError ("Einstan failed. Not an exists" )
280+
281+ def split (self , at = None ):
282+ ctx , goal = self .goals [- 1 ]
283+ if at is None :
284+ if smt .is_and (goal ):
285+ self .goals .pop ()
286+ self .goals .extend ([Sequent (ctx , c ) for c in goal .children ()])
287+ if smt .is_eq (goal ):
288+ self .goals .pop ()
289+ self .goals .append (Sequent (ctx , smt .Implies (goal .arg (0 ), goal .arg (1 ))))
290+ self .goals .append (Sequent (ctx , smt .Implies (goal .arg (1 ), goal .arg (0 ))))
291+ else :
292+ raise ValueError ("Split failed" )
293+ else :
294+ if smt .is_or (ctx [at ]):
295+ self .goals .pop ()
296+ for c in ctx [at ].children ():
297+ self .goals .append (Sequent (ctx [:at ] + [c ] + ctx [at + 1 :], goal ))
298+ if smt .is_and (ctx [at ]):
299+ self .goals .pop ()
300+ self .goals .append (
301+ Sequent (ctx [:at ] + ctx [at ].children () + ctx [at + 1 :], goal )
302+ )
303+ else :
304+ raise ValueError ("Split failed" )
305+
306+ def left (self , n = 0 ):
307+ ctx , goal = self .goals [- 1 ]
308+ if smt .is_or (goal ):
309+ if n is None :
310+ n = 0
311+ self .goals [- 1 ] = Sequent (ctx , goal .arg (n ))
312+ return self .top_goal ()
243313 else :
244- raise ValueError ("Split failed. Not an and" )
314+ raise ValueError ("Left failed. Not an or" )
315+
316+ def right (self ):
317+ ctx , goal = self .goals [- 1 ]
318+ if smt .is_or (goal ):
319+ self .goals [- 1 ] = Sequent (ctx , goal .arg (goal .num_args () - 1 ))
320+ return self .top_goal ()
321+ else :
322+ raise ValueError ("Right failed. Not an or" )
245323
246324 def exists (self , * ts ):
247325 ctx , goal = self .goals [- 1 ]
248326 lemma = kd .kernel .forget2 (ts , goal )
249327 self .lemmas .append (lemma )
250- self .goals [- 1 ] = (ctx , lemma .thm .arg (0 ))
251- return self
328+ self .goals [- 1 ] = Sequent (ctx , lemma .thm .arg (0 ))
329+ return self .top_goal ()
330+
331+ def rewrite (self , rule , at = None , rev = False ):
332+ """
333+ `rewrite` allows you to apply rewrite rule (which may either be a Proof or an index into the context) to the goal or to the context.
334+ """
335+ ctx , goal = self .goals [- 1 ]
336+ if isinstance (rule , int ):
337+ rulethm = ctx [rule ]
338+ elif kd .kernel .is_proof (rule ):
339+ rulethm = rule .thm
340+ if smt .is_quantifier (rulethm ) and rulethm .is_forall ():
341+ vs , body = kd .utils .open_binder (rulethm )
342+ else :
343+ vs = []
344+ body = rulethm
345+ if smt .is_eq (body ):
346+ lhs , rhs = body .arg (0 ), body .arg (1 )
347+ if rev :
348+ lhs , rhs = rhs , lhs
349+ else :
350+ raise ValueError (f"Rewrite tactic failed. Not an equality { rulethm } " )
351+ if at is None :
352+ target = goal
353+ elif isinstance (at , int ):
354+ target = ctx [at ]
355+ else :
356+ raise ValueError (
357+ "Rewrite tactic failed. `at` is not an index into the context"
358+ )
359+ subst = kd .utils .pmatch_rec (vs , lhs , target )
360+ if subst is None :
361+ raise ValueError (
362+ f"Rewrite tactic failed to apply lemma { rulethm } to goal { goal } "
363+ )
364+ else :
365+ self .goals .pop ()
366+ lhs1 = smt .substitute (lhs , * [(v , t ) for v , t in subst .items ()])
367+ rhs1 = smt .substitute (rhs , * [(v , t ) for v , t in subst .items ()])
368+ target : smt .BoolRef = smt .substitute (target , (lhs1 , rhs1 ))
369+ self .lemmas .append (kd .kernel .instan2 ([subst [v ] for v in vs ], rulethm ))
370+ if kd .kernel .is_proof (rule ):
371+ self .lemmas .append (rule )
372+ if at is None :
373+ self .goals .append (Sequent (ctx , target ))
374+ else :
375+ self .goals .append (Sequent (ctx [:at ] + [target ] + ctx [at + 1 :], goal ))
376+ return self .top_goal ()
377+
378+ def rw (self , rule , at = None , rev = False ):
379+ return self .rewrite (rule , at = at , rev = rev )
380+
381+ def unfold (self , decl : smt .FuncDeclRef ):
382+ if hasattr (decl , "defn" ):
383+ return self .rewrite (decl .defn )
384+ else :
385+ raise ValueError ("Unfold failed. Not a defined function" )
252386
253387 def apply (self , pf : kd .kernel .Proof , rev = False ):
254388 ctx , goal = self .goals .pop ()
@@ -273,32 +407,38 @@ def apply(self, pf: kd.kernel.Proof, rev=False):
273407 pf1 = kd .kernel .instan ([subst [v ] for v in vs ], pf )
274408 self .lemmas .append (pf1 )
275409 if smt .is_implies (pf1 .thm ):
276- self .goals .append ((ctx , pf1 .thm .arg (0 )))
410+ self .goals .append (Sequent (ctx , pf1 .thm .arg (0 )))
277411 elif smt .is_eq (pf1 .thm ):
278412 if rev :
279- self .goals .append ((ctx , pf1 .thm .arg (0 )))
413+ self .goals .append (Sequent (ctx , pf1 .thm .arg (0 )))
280414 else :
281- self .goals .append ((ctx , pf1 .thm .arg (1 )))
282- return self
415+ self .goals .append (Sequent (ctx , pf1 .thm .arg (1 )))
416+ return self . top_goal ()
283417
284418 def assumption (self ):
285419 ctx , goal = self .goals .pop ()
286420 if any ([goal .eq (h ) for h in ctx ]):
287- return self
421+ return self . top_goal ()
288422 else :
289423 raise ValueError ("Assumption tactic failed" , goal , ctx )
290424
291425 def have (self , conc , ** kwargs ):
292426 ctx , goal = self .goals .pop ()
293427 self .lemmas .append (lemma (smt .Implies (smt .And (ctx ), conc )), ** kwargs )
294- self .goals .append ((ctx + [conc ], conc ))
295- return self
428+ self .goals .append (Sequent (ctx + [conc ], conc ))
429+ return self . top_goal ()
296430
297- def __repr__ (self ):
431+ # TODO
432+ # def search():
433+ # def calc
434+
435+ def top_goal (self ):
298436 if len (self .goals ) == 0 :
299437 return "Nothing to do. Hooray!"
300- ctx , goal = self .goals [- 1 ]
301- return repr (ctx ) + " ?|- " + repr (goal )
438+ return self .goals [- 1 ]
439+
440+ def __repr__ (self ):
441+ return repr (self .top_goal ())
302442
303443 def qed (self ):
304- return lemma (self .thm , by = self .lemmas )
444+ return kd . kernel . lemma (self .thm , by = self .lemmas )
0 commit comments