Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Strata.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import Strata.Backends.CBMC
import Strata.DL.Imperative.CFGToCProverGOTO
import Strata.DL.Imperative.ToCProverGOTO
import Strata.DL.SMT.Denote
import Strata.DL.SMT.FactoryCorrect
import Strata.DL.SMT.Translate

/- Code Transforms — additional -/
Expand Down
54 changes: 39 additions & 15 deletions Strata/DL/SMT/Denote.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ currently supported. The core entry point is `denoteTerm`, which builds a
`TermDenoteResult` describing both the type of a term and a semantic interpreter
for it. The surrounding infrastructure tracks the well-formedness of
term and uninterpreted-function contexts so that evaluation is safe.

The denotation uses propositional extensionality (`propext`) and
`Classical.propDecidable` (excluded middle) to make `if`-then-`else` over
`Prop`-valued conditions definable. Downstream correctness proofs
(see `FactoryCorrect.lean`) inherit these dependencies.
-/

open Strata.SMT
Expand Down Expand Up @@ -137,7 +142,7 @@ def substituteIFIS (isctx : ISContext) (iF : Core.SMT.IF) : Core.SMT.IF :=
mutual

/-- Interpret primitive SMT types as Lean types, when supported. -/
def denotePrimSort (sctx : SortContext) (pty : TermPrimType) : Option (SortDenoteResult sctx) := do
@[expose] def denotePrimSort (sctx : SortContext) (pty : TermPrimType) : Option (SortDenoteResult sctx) := do
match pty with
| .bool => return fun _ => Prop
| .int => return fun _ => Int
Expand Down Expand Up @@ -191,7 +196,7 @@ Interpret an SMT `TermType` as a Lean `Type`, when supported.

Returns `none` when we lack an interpretation (e.g. for reals).
-/
def denoteSort (sctx : SortContext) (ty : TermType) : Option (SortDenoteResult sctx) := do
@[expose] def denoteSort (sctx : SortContext) (ty : TermType) : Option (SortDenoteResult sctx) := do
match ty with
| .prim pty => denotePrimSort sctx pty
| .option ty =>
Expand Down Expand Up @@ -256,6 +261,7 @@ theorem denoteFunSortCons_isSome (h : (denoteFunSort sctx (a :: as) out).isSome)
have ⟨h1 , h2⟩ := (Option.any_eq_true_iff_get _ _).mp h
exact ⟨h1, h2⟩


theorem arrow_of_denoteFunSortCons_isSome (h : (denoteFunSort sctx (a :: as) out).isSome) :
have has := denoteFunSortCons_isSome h
(denoteFunSort sctx (a :: as) out).get h sΓ =
Expand Down Expand Up @@ -523,7 +529,7 @@ def bindExistsVar : QuantVarBinder := fun {n} {ty} ctx hTy =>
{ sΓ := tdi.sΓ, hsΓ := tdi.hsΓ, tΓ := { ufs := tdi.tΓ.ufs, vs := vΓ' }, htΓ := { hv := hv', huf := tdi.htΓ.huf } }
ft' tdi'

def buildQuant (bindVar : QuantVarBinder) (ctx : Context) (vs : List TermVar)
@[expose] def buildQuant (bindVar : QuantVarBinder) (ctx : Context) (vs : List TermVar)
Comment thread
tautschnig marked this conversation as resolved.
Outdated
(hTys : (denoteFunSort ctx.sctx vs (.prim .bool)).isSome)
(bodyFt : TermDenoteInput { sctx := ctx.sctx, tctx := { vs := vs.reverse ++ ctx.tctx.vs, ufs := ctx.tctx.ufs } } → Prop)
(tdi : TermDenoteInput ctx)
Expand All @@ -538,20 +544,21 @@ def buildQuant (bindVar : QuantVarBinder) (ctx : Context) (vs : List TermVar)
let ft' := buildQuant bindVar ctx' vs hTys' (hvs ▸ bodyFt)
bindVar (n := n) (ty := ty) ctx (denoteFunSortCons_isSome hTys).left ft' tdi

def buildExists (ctx : Context) (vs : List TermVar)
@[expose] def buildExists (ctx : Context) (vs : List TermVar)
(hTys : (denoteFunSort ctx.sctx vs (.prim .bool)).isSome)
(bodyFt : TermDenoteInput { sctx := ctx.sctx, tctx := { vs := vs.reverse ++ ctx.tctx.vs, ufs := ctx.tctx.ufs } } → Prop)
(tdi : TermDenoteInput ctx)
: Prop :=
buildQuant bindExistsVar ctx vs hTys bodyFt tdi

def buildForall (ctx : Context) (vs : List TermVar)
@[expose] def buildForall (ctx : Context) (vs : List TermVar)
(hTys : (denoteFunSort ctx.sctx vs (.prim .bool)).isSome)
(bodyFt : TermDenoteInput { sctx := ctx.sctx, tctx := { vs := vs.reverse ++ ctx.tctx.vs, ufs := ctx.tctx.ufs } } → Prop)
(tdi : TermDenoteInput ctx)
: Prop :=
buildQuant bindForallVar ctx vs hTys bodyFt tdi


mutual

/-
Expand All @@ -563,7 +570,7 @@ Noncomputable because of `ite` case. Two conditions are needed to make this func
Attempt to interpret a single SMT term under `ctx`, returning its Lean type
and semantics when successful.
-/
noncomputable def denoteTerm (ctx : Context) (t : Term) : Option (TermDenoteResult ctx) := do
@[expose] noncomputable def denoteTerm (ctx : Context) (t : Term) : Option (TermDenoteResult ctx) := do
match t with
-- Variable lookup: if `v` is declared in `ctx.tctx.vs` and its sort can be
-- interpreted, return the corresponding semantic value from `tdi.tΓ.vs`.
Expand Down Expand Up @@ -877,15 +884,15 @@ noncomputable def denoteTerm (ctx : Context) (t : Term) : Option (TermDenoteResu
/--
Interpret every term in a list, short-circuiting if any sub-term fails.
-/
noncomputable def denoteTerms (ctx : Context) (ts : List Term) : Option (List (TermDenoteResult ctx)) := do
@[expose] noncomputable def denoteTerms (ctx : Context) (ts : List Term) : Option (List (TermDenoteResult ctx)) := do
match ts with
| [] => return []
| a :: as =>
let a ← denoteTerm ctx a
let as ← denoteTerms ctx as
return a :: as

noncomputable def leftAssoc (ctx : Context) (ty : TermType) (h : (denoteSort ctx.sctx ty).isSome)
@[expose] noncomputable def leftAssoc (ctx : Context) (ty : TermType) (h : (denoteSort ctx.sctx ty).isSome)
(op : (sdi : SortDenoteInput ctx.sctx) → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi)
(ts : List (TermDenoteResult ctx)) : Option (TermDenoteResult ctx) := do
let t₁ :: t₂ :: ts := ts | none
Expand All @@ -908,7 +915,7 @@ where
else
none

noncomputable def rightAssoc (ctx : Context) (ty : TermType) (h : (denoteSort ctx.sctx ty).isSome)
@[expose] noncomputable def rightAssoc (ctx : Context) (ty : TermType) (h : (denoteSort ctx.sctx ty).isSome)
(op : (sdi : SortDenoteInput ctx.sctx) → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi)
(ts : List (TermDenoteResult ctx)) : Option (TermDenoteResult ctx) := do
let ft ← go ts
Expand All @@ -935,7 +942,7 @@ where
else
none

noncomputable def chainable (ctx ty h)
@[expose] noncomputable def chainable (ctx ty h)
(op : (sdi : SortDenoteInput ctx.sctx) → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi → Prop)
(ts : List (TermDenoteResult ctx)) : Option (TermDenoteResult ctx) := do
let t₁ :: t₂ :: ts := ts | none
Expand All @@ -949,7 +956,7 @@ noncomputable def chainable (ctx ty h)
else
none

noncomputable def chainable.go (ctx ty h)
@[expose] noncomputable def chainable.go (ctx ty h)
(op : (sdi : SortDenoteInput ctx.sctx) → (denoteSort ctx.sctx ty).get h sdi → (denoteSort ctx.sctx ty).get h sdi → Prop)
(ft : TermDenoteInput ctx → Prop) (ft₁ : (tdi : TermDenoteInput ctx) → (denoteSort ctx.sctx ty).get h ⟨tdi.sΓ, tdi.hsΓ⟩)
(ts : List (TermDenoteResult ctx)) : Option (TermDenoteResult ctx) := do match ts with
Expand All @@ -963,22 +970,38 @@ noncomputable def chainable.go (ctx ty h)

end


/--
Interpret a ground boolean term in the empty context.
-/
@[simp]
noncomputable def denoteBoolTermAux (t : Term) : Option Prop := do
@[expose, simp] noncomputable def denoteBoolTermAux (t : Term) : Option Prop := do
let some ⟨.prim .bool, _, fi⟩ := denoteTerm {} t | none
return fi ⟨[], { h := rfl, ha := fun _ hi => nomatch hi }, ⟨[], []⟩, ⟨{ h := rfl, ha := fun _ hi => nomatch hi }, { h := rfl, ha := fun _ hi => nomatch hi }⟩⟩

/--
Interpret a ground integer term in the empty context.
-/
@[simp]
noncomputable def denoteIntTermAux (t : Term) : Option Int := do
@[expose, simp] noncomputable def denoteIntTermAux (t : Term) : Option Int := do
let some ⟨.prim .int, _, fi⟩ := denoteTerm {} t | none
return fi ⟨[], { h := rfl, ha := fun _ hi => nomatch hi }, ⟨[], []⟩, ⟨{ h := rfl, ha := fun _ hi => nomatch hi }, { h := rfl, ha := fun _ hi => nomatch hi }⟩⟩

/--
Interpret a ground bitvector term in the empty context.
-/
@[expose, simp] noncomputable def denoteBVTermAux (n : Nat) (t : Term) : Option (BitVec n) := do
let some ⟨.prim (.bitvec m), _, fi⟩ := denoteTerm {} t | none
if h : m = n then
return h ▸ fi ⟨[], { h := rfl, ha := fun _ hi => nomatch hi }, ⟨[], []⟩, ⟨{ h := rfl, ha := fun _ hi => nomatch hi }, { h := rfl, ha := fun _ hi => nomatch hi }⟩⟩
else
none

/--
Interpret a ground string term in the empty context.
-/
@[expose, simp] noncomputable def denoteStringTermAux (t : Term) : Option String := do
let some ⟨.prim .string, _, fi⟩ := denoteTerm {} t | none
return fi ⟨[], { h := rfl, ha := fun _ hi => nomatch hi }, ⟨[], []⟩, ⟨{ h := rfl, ha := fun _ hi => nomatch hi }, { h := rfl, ha := fun _ hi => nomatch hi }⟩⟩

/--
Eliminate one uninterpreted sort binder by quantifying over all of its
semantic realizations and extending the sort environment accordingly.
Expand Down Expand Up @@ -1228,3 +1251,4 @@ noncomputable def denoteQuery (ctx : Core.SMT.Context) (assums : List Term) (con
let ufs := ctx.ufs.toList.reverse
let ifs := ctx.ifs.toList.reverse
(denoteBoolTermFromContext uss iss ufs ifs t).map PLift.down

43 changes: 41 additions & 2 deletions Strata/DL/SMT/Factory.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
-/
module

import all Strata.DL.Util.BitVec
public import Strata.DL.Util.BitVec
public import Strata.DL.SMT.Function
public import Strata.DL.SMT.Op
public import Strata.DL.SMT.Term
public import Strata.DL.SMT.TermType


public section
@[expose] public section
/-!
Based on Cedar's Term language.
(https://github.com/cedar-policy/cedar-spec/blob/main/cedar-lean/Cedar/SymCC/Factory.lean)
Expand Down Expand Up @@ -40,12 +40,15 @@ namespace Factory

---------- Term constructors ----------

-- Correctness: `Factory.noneOf_correct`
def noneOf (ty : TermType) : Term := .none ty

-- Correctness: `Factory.someOf_correct`
def someOf (t : Term) : Term := .some t

---------- SMTLib core theory of equality with uninterpreted functions (`UF`) ----------

-- Correctness: `Factory.not_correct`
def not : Term → Term
| .prim (.bool b) => ! b
| .app .not [t'] _ => t'
Expand All @@ -56,6 +59,7 @@ def opposites : Term → Term → Bool
| .app .not [t₁] _, t₂ => t₁ = t₂
| _, _ => false

-- Correctness: `Factory.and_correct`
def and (t₁ t₂ : Term) : Term :=
if t₁ = t₂ || t₂ = true
then t₁
Expand All @@ -65,6 +69,7 @@ def and (t₁ t₂ : Term) : Term :=
then false
else .app .and [t₁, t₂] .bool

-- Correctness: `Factory.or_correct`
def or (t₁ t₂ : Term) : Term :=
if t₁ = t₂ || t₂ = false
then t₁
Expand All @@ -74,9 +79,12 @@ def or (t₁ t₂ : Term) : Term :=
then true
else .app .or [t₁, t₂] .bool

-- Correctness: `Factory.implies_correct`
def implies (t₁ t₂ : Term) : Term :=
or (not t₁) t₂

-- Correctness: `Factory.eq_correct_bool`, `Factory.eq_correct_int`,
-- `Factory.eq_correct_bv`, `Factory.eq_correct_string`
def eq (t₁ t₂ : Term) : Term :=
if t₁ = t₂
then true
Expand All @@ -88,6 +96,8 @@ def eq (t₁ t₂ : Term) : Term :=
| .none _, .some _ => false
| _, _ => .app .eq [t₁, t₂] .bool

-- Correctness: `Factory.ite_correct_bool`, `Factory.ite_correct_int`,
-- `Factory.ite_correct_bv`, `Factory.ite_correct_string`
def ite (t₁ t₂ t₃ : Term) : Term :=
if t₁ = true || t₂ = t₃
then t₂
Expand All @@ -113,6 +123,7 @@ Returns the result of applying function to a list of terms.

(TODO) Arity check?
-/
-- Correctness: `Factory.app_uf_correct`
def app : Function → List Term → Term
| .uf f, ts => .app (.uf f) ts f.out

Expand All @@ -129,6 +140,7 @@ theorem mkSimpleTriggerIsSimple: isSimpleTrigger (mkSimpleTrigger x ty) := by
simp [isSimpleTrigger, mkSimpleTrigger]

-- Note: we could coalesce nested quantifiers here, since SMT-Lib allows multiple variables to be bound at once.
-- TODO: Its correctness could not be proven due to its complexity. Contribution is welcome
Comment thread
tautschnig marked this conversation as resolved.
def quant (qk : QuantifierKind) (x : String) (ty : TermType) (tr : Term) (e : Term) : Term :=
-- Check if we can coalesce with a nested quantifier
match e with
Expand All @@ -148,10 +160,12 @@ def quant (qk : QuantifierKind) (x : String) (ty : TermType) (tr : Term) (e : Te

---------- SMTLib theory of integer numbers (`Ints`) ----------

-- Correctness: `Factory.intNeg_correct`
def intNeg : Term → Term
| .prim (.int i) => i.neg
| t => .app .neg [t] t.typeOf

-- Correctness: `Factory.intAbs_correct`
def intAbs : Term → Term
| .prim (.int i) => Int.ofNat i.natAbs
| t => .app .abs [t] t.typeOf
Expand All @@ -161,20 +175,29 @@ def intapp (op : Op) (fn : Int → Int → Int) (t₁ t₂ : Term) : Term :=
| .prim (.int i₁), .prim (.int i₂) => fn i₁ i₂
| _, _ => .app op [t₁, t₂] t₁.typeOf

-- Correctness: `Factory.intSub_correct`
def intSub := intapp .sub Int.sub
-- Correctness: `Factory.intAdd_correct`
def intAdd := intapp .add Int.add
-- Correctness: `Factory.intMul_correct`
def intMul := intapp .mul Int.mul
-- Correctness: `Factory.intDiv_correct`
def intDiv := intapp .div Int.ediv
-- Correctness: `Factory.intMod_correct`
def intMod := intapp .mod Int.emod

def intcmp (op : Op) (fn : Int → Int → Bool) (t₁ t₂ : Term) : Term :=
match t₁, t₂ with
| .prim (.int i₁), .prim (.int i₂) => fn i₁ i₂
| _, _ => .app op [t₁, t₂] .bool

-- Correctness: `Factory.intLe_correct`
def intLe := intcmp .le (λ i₁ i₂ => i₁ <= i₂)
-- Correctness: `Factory.intLt_correct`
def intLt := intcmp .lt (λ i₁ i₂ => i₁ < i₂)
-- Correctness: `Factory.intGe_correct`
def intGe := intcmp .ge (λ i₁ i₂ => i₁ >= i₂)
-- Correctness: `Factory.intGt_correct`
def intGt := intcmp .gt (λ i₁ i₂ => i₁ > i₂)

---------- SMTLib theory of finite bitvectors (`BV`) ----------
Expand All @@ -184,6 +207,7 @@ def intGt := intcmp .gt (λ i₁ i₂ => i₁ > i₂)
-- approach is sufficient for the strong PE property we care about: if given a
-- fully concrete input, the symbolic evaluator returns a fully concrete output.

-- Correctness: `Factory.bvneg_correct`
def bvneg : Term → Term
| .prim (.bitvec b) => b.neg
| t => .app .bvneg [t] t.typeOf
Expand All @@ -195,11 +219,16 @@ def bvapp (op : Op) (fn : ∀ {n}, BitVec n → BitVec n → BitVec n) (t₁ t
| _, _ =>
.app op [t₁, t₂] t₁.typeOf

-- Correctness: `Factory.bvadd_correct`
def bvadd := bvapp .bvadd BitVec.add
-- Correctness: `Factory.bvsub_correct`
def bvsub := bvapp .bvsub BitVec.sub
-- Correctness: `Factory.bvmul_correct`
def bvmul := bvapp .bvmul BitVec.mul

-- Correctness: `Factory.bvshl_correct`
def bvshl := bvapp .bvshl (λ b₁ b₂ => b₁ <<< b₂)
-- Correctness: `Factory.bvlshr_correct`
def bvlshr := bvapp .bvlshr (λ b₁ b₂ => b₁ >>> b₂)

def bvcmp (op : Op) (fn : ∀ {n}, BitVec n → BitVec n → Bool) (t₁ t₂ : Term) : Term :=
Expand All @@ -209,11 +238,16 @@ def bvcmp (op : Op) (fn : ∀ {n}, BitVec n → BitVec n → Bool) (t₁ t₂ :
| _, _ =>
.app op [t₁, t₂] .bool

-- Correctness: `Factory.bvslt_correct`
def bvslt := bvcmp .bvslt BitVec.slt
-- Correctness: `Factory.bvsle_correct`
def bvsle := bvcmp .bvsle BitVec.sle
-- Correctness: `Factory.bvult_correct`
def bvult := bvcmp .bvult BitVec.ult
-- Correctness: `Factory.bvule_correct`
def bvule := bvcmp .bvule BitVec.ule

-- Correctness: `Factory.bvnego_correct`
def bvnego : Term → Term
| .prim (@TermPrim.bitvec n b₁) => BitVec.overflows n (-b₁.toInt)
| t => .app .bvnego [t] .bool
Expand All @@ -224,14 +258,18 @@ def bvso (op : Op) (fn : Int → Int → Int) (t₁ t₂ : Term) : Term :=
BitVec.overflows n (fn b₁.toInt b₂.toInt)
| _, _ => .app op [t₁, t₂] .bool

-- Correctness: `Factory.bvsaddo_correct`
def bvsaddo := bvso .bvsaddo (· + ·)
-- Correctness: `Factory.bvssubo_correct`
def bvssubo := bvso .bvssubo (· - ·)
-- Correctness: `Factory.bvsmulo_correct`
def bvsmulo := bvso .bvsmulo (· * ·)

/-
Note that BitVec defines zero_extend differently from SMTLib,
so we compensate for the difference in partial evaluation.
-/
-- Correctness: `Factory.zero_extend_correct`
def zero_extend (n : Nat) : Term → Term
| .prim (@TermPrim.bitvec m b) =>
BitVec.zeroExtend (n + m) b
Expand All @@ -243,6 +281,7 @@ def zero_extend (n : Nat) : Term → Term

---------- Core ADT operators with a trusted mapping to SMT ----------

-- Correctness: `Factory.option_get_some_correct`
def option.get : Term → Term
| .some t => t
| t =>
Expand Down
Loading
Loading