@@ -432,81 +432,88 @@ def Enum(name: str, args: str | list[str]) -> smt.DatatypeSortRef:
432432# smt.DatatypeRef.rel = lambda self, *args: self.rel(*args)
433433
434434
435- def InductiveRel (name : str , * params : smt .ExprRef ) -> smt .Datatype :
436- """Define an inductive type of evidence and a relation the recurses on that evidence
437-
438- >>> x = smt.Int("x")
439- >>> Even = InductiveRel("Even", x)
440- >>> Even.declare("Ev_Z", pred = x == 0)
441- >>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
442- >>> Even = Even.create()
443- >>> ev = smt.Const("ev", Even)
444- >>> ev.rel(4)
445- even(4)[ev]
446- >>> ev(4)
447- even(4)[ev]
448- >>> Even(4)
449- even(4)
435+ class _InductiveRelDatatype (kd .kernel ._InductiveDatatype ):
436+ """
437+ Subclass of _InductiveDatatype for inductive relations.
438+ This replaces the previous approach of monkey-patching declare and create methods.
450439 """
451440
452- dt = Inductive (name )
453-
454- relname = name .lower ()
455- olddeclare = dt .declare
456- preds = [] # tuck away extra predicate
457-
458- def declare (
459- name , * args , pred = None
460- ): # TODO: would it ever make sense to not have a pred?
461- olddeclare (name , * args )
462- preds .append (pred )
463-
464- dt .declare = declare
465-
466- oldcreate = dt .create
467-
468- def create_relation (dt ):
469- """
470- When inductive is done being defined, call this function
471- """
472- ev = smt .FreshConst (dt , prefix = name .lower ())
473- rel = smt .Function (relname , dt , * [x .sort () for x in params ], smt .BoolSort ())
441+ def __init__ (self , name : str , * params : smt .ExprRef ):
442+ """Initialize an inductive relation datatype."""
443+ super ().__init__ (name )
444+ self ._params = params
445+ self ._relname = name .lower ()
446+ self ._preds = [] # Store predicates for each constructor
447+
448+ def declare (self , name , * args , pred = None ):
449+ """Override declare to also store predicates."""
450+ super ().declare (name , * args )
451+ self ._preds .append (pred )
452+
453+ def _create_relation (self , dt ):
454+ """Create the relation definition for this inductive type."""
455+ ev = smt .FreshConst (dt , prefix = self ._relname )
456+ rel = smt .Function (
457+ self ._relname , dt , * [x .sort () for x in self ._params ], smt .BoolSort ()
458+ )
474459 cases = []
475460 for i in range (dt .num_constructors ()):
476461 precond = dt .recognizer (i )(ev ) # recognize case of the evidence
477- pred = preds [i ] # In this case, this predicate should be true
462+ pred = self . _preds [i ] # In this case, this predicate should be true
478463 if pred is None :
479464 res = smt .BoolVal (True )
480465 elif isinstance (pred , smt .ExprRef ):
481466 res = pred
482467 else :
483- args = [dt .accessor (i , j )(ev ) for j in range (dt .constructor (i ).arity ())]
468+ args = [
469+ dt .accessor (i , j )(ev ) for j in range (dt .constructor (i ).arity ())
470+ ]
484471 res = pred (* args )
485472 cases .append ((precond , res ))
486- rel = kd .define (relname , list (params ), smt .Lambda ([ev ], kd .cond (* cases )))
473+ rel = kd .define (
474+ self ._relname , list (self ._params ), smt .Lambda ([ev ], kd .cond (* cases ))
475+ )
487476 return rel
488477
489- def create ():
490- dt = oldcreate ()
491- if any (p is not None for p in preds ):
478+ def create (self ):
479+ """Override create to also set up the relation."""
480+ dt = super ().create ()
481+ if any (p is not None for p in self ._preds ):
492482 dtrel = smt .Function (
493- relname , * [x .sort () for x in params ], smt .ArraySort (dt , smt .BoolSort ())
483+ self ._relname ,
484+ * [x .sort () for x in self ._params ],
485+ smt .ArraySort (dt , smt .BoolSort ()),
494486 )
495487 kd .notation .call .register (dt , lambda self , * args : dtrel (* args )[self ])
496488 rel .register (
497489 dt , lambda self , * args : dtrel (* args )[self ]
498490 ) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.
499- dtrel = create_relation (dt )
491+ dtrel = self . _create_relation (dt )
500492 dt .rel = dtrel
501- # ev = smt.FreshConst(dt, prefix=name.lower())
502- call_dict [dt ] = dtrel # lambda *args: smt.Lambda([ev], ev.rel(*args))
493+ call_dict [dt ] = dtrel
503494
504- if len (params ) == 0 :
495+ if len (self . _params ) == 0 :
505496 kd .notation .wf .register (dt , lambda x : x .rel ())
506497 return dt
507498
508- dt .create = create
509- return dt
499+
500+ def InductiveRel (name : str , * params : smt .ExprRef ) -> _InductiveRelDatatype :
501+ """Define an inductive type of evidence and a relation the recurses on that evidence
502+
503+ >>> x = smt.Int("x")
504+ >>> Even = InductiveRel("Even", x)
505+ >>> Even.declare("Ev_Z", pred = x == 0)
506+ >>> Even.declare("Ev_SS", ("sub2_evidence", Even), pred = lambda evid: evid.rel(x-2))
507+ >>> Even = Even.create()
508+ >>> ev = smt.Const("ev", Even)
509+ >>> ev.rel(4)
510+ even(4)[ev]
511+ >>> ev(4)
512+ even(4)[ev]
513+ >>> Even(4)
514+ even(4)
515+ """
516+ return _InductiveRelDatatype (name , * params )
510517
511518
512519def inj_lemmas (dt : smt .DatatypeSortRef ) -> list [kd .kernel .Proof ]:
0 commit comments