Skip to content

Commit 6442e1d

Browse files
committed
Option and List helpers.
1 parent f93314c commit 6442e1d

File tree

7 files changed

+304
-14
lines changed

7 files changed

+304
-14
lines changed

examples/soft_found/lf/IndProp.ipynb

Lines changed: 177 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,13 +409,9 @@
409409
"cell_type": "markdown",
410410
"metadata": {},
411411
"source": [
412-
"\n",
413-
"\n",
414-
"\n",
415-
"\n",
416412
"\n",
417413
"(* ================================================================= *)\n",
418-
"(** ** Example: Permutations *)\n",
414+
"# Example: Permutations *)\n",
419415
"\n",
420416
"(** The familiar mathematical concept of _permutation_ also has an\n",
421417
" elegant formulation as an inductive relation. For simplicity,\n",
@@ -453,7 +449,182 @@
453449
" - apply perm3_swap12.\n",
454450
" - apply perm3_swap23. Qed.\n",
455451
"\n",
456-
"(* ================================================================= *)\n",
452+
"(* ================================================================= *)"
453+
]
454+
},
455+
{
456+
"cell_type": "code",
457+
"execution_count": null,
458+
"metadata": {},
459+
"outputs": [
460+
{
461+
"name": "stdout",
462+
"output_type": "stream",
463+
"text": [
464+
"WARNING: Redefining function perm3 from |- ForAll([perm3!150, seq!151, seq!152],\n",
465+
" perm3(perm3!150, seq!151, seq!152) ==\n",
466+
" If(is(perm3_swap12, perm3!150),\n",
467+
" And(Nth(seq!151, 0) == Nth(seq!152, 1),\n",
468+
" Nth(seq!151, 1) == Nth(seq!152, 0),\n",
469+
" Nth(seq!151, 2) == Nth(seq!152, 2)),\n",
470+
" If(is(perm3_swap23, perm3!150),\n",
471+
" And(Nth(seq!151, 0) == Nth(seq!152, 0),\n",
472+
" Nth(seq!151, 1) == Nth(seq!152, 2),\n",
473+
" Nth(seq!151, 2) == Nth(seq!152, 1)),\n",
474+
" If(is(perm3_trans, perm3!150),\n",
475+
" And(perm3(ev1(perm3!150),\n",
476+
" seq!151,\n",
477+
" y(perm3!150)),\n",
478+
" perm3(ev2(perm3!150),\n",
479+
" y(perm3!150),\n",
480+
" seq!152)),\n",
481+
" unreachable!153)))) to ForAll([perm3!154, seq!155, seq!156],\n",
482+
" perm3(perm3!154, seq!155, seq!156) ==\n",
483+
" If(is(perm3_swap12, perm3!154),\n",
484+
" And(Nth(seq!155, 0) == Nth(seq!156, 1),\n",
485+
" Nth(seq!155, 1) == Nth(seq!156, 0),\n",
486+
" Nth(seq!155, 2) == Nth(seq!156, 2)),\n",
487+
" If(is(perm3_swap23, perm3!154),\n",
488+
" And(Nth(seq!155, 0) == Nth(seq!156, 0),\n",
489+
" Nth(seq!155, 1) == Nth(seq!156, 2),\n",
490+
" Nth(seq!155, 2) == Nth(seq!156, 1)),\n",
491+
" If(is(perm3_trans, perm3!154),\n",
492+
" And(perm3(ev1(perm3!154),\n",
493+
" seq!155,\n",
494+
" y(perm3!154)),\n",
495+
" perm3(ev2(perm3!154),\n",
496+
" y(perm3!154),\n",
497+
" seq!156)),\n",
498+
" unreachable!157))))\n",
499+
"[|- ForAll([perm3!154, seq!155, seq!156],\n",
500+
" perm3(perm3!154, seq!155, seq!156) ==\n",
501+
" If(is(perm3_swap12, perm3!154),\n",
502+
" And(Nth(seq!155, 0) == Nth(seq!156, 1),\n",
503+
" Nth(seq!155, 1) == Nth(seq!156, 0),\n",
504+
" Nth(seq!155, 2) == Nth(seq!156, 2)),\n",
505+
" If(is(perm3_swap23, perm3!154),\n",
506+
" And(Nth(seq!155, 0) == Nth(seq!156, 0),\n",
507+
" Nth(seq!155, 1) == Nth(seq!156, 2),\n",
508+
" Nth(seq!155, 2) == Nth(seq!156, 1)),\n",
509+
" If(is(perm3_trans, perm3!154),\n",
510+
" And(perm3(ev1(perm3!154),\n",
511+
" seq!155,\n",
512+
" y(perm3!154)),\n",
513+
" perm3(ev2(perm3!154),\n",
514+
" y(perm3!154),\n",
515+
" seq!156)),\n",
516+
" unreachable!157)))), |- Implies(perm3(perm3_trans(perm3_swap12,\n",
517+
" perm3_swap23,\n",
518+
" Concat(Unit(2),\n",
519+
" Concat(Unit(1), Unit(3)))),\n",
520+
" Concat(Unit(1), Concat(Unit(2), Unit(3))),\n",
521+
" Concat(Unit(2), Concat(Unit(3), Unit(1)))),\n",
522+
" Exists(ev,\n",
523+
" perm3(ev,\n",
524+
" Concat(Unit(1),\n",
525+
" Concat(Unit(2), Unit(3))),\n",
526+
" Concat(Unit(2),\n",
527+
" Concat(Unit(3), Unit(1))))))]\n"
528+
]
529+
},
530+
{
531+
"data": {
532+
"text/html": [
533+
"⊦Exists(ev,\n",
534+
" perm3(ev,\n",
535+
" Concat(Unit(1), Concat(Unit(2), Unit(3))),\n",
536+
" Concat(Unit(2), Concat(Unit(3), Unit(1)))))"
537+
],
538+
"text/plain": [
539+
"|- Exists(ev,\n",
540+
" perm3(ev,\n",
541+
" Concat(Unit(1), Concat(Unit(2), Unit(3))),\n",
542+
" Concat(Unit(2), Concat(Unit(3), Unit(1)))))"
543+
]
544+
},
545+
"execution_count": 2,
546+
"metadata": {},
547+
"output_type": "execute_result"
548+
}
549+
],
550+
"source": [
551+
"from kdrag.all import *\n",
552+
"\n",
553+
"Perm3 = kd.notation.InductiveRel(\"Perm3\", smt.SeqSort(smt.IntSort()), smt.SeqSort(smt.IntSort()), admit=True)\n",
554+
"Perm3.declare(\"perm3_swap12\", pred= lambda s1,s2: smt.And(s1[0] == s2[1], s1[1] == s2[0], s1[2] == s2[2]))\n",
555+
"Perm3.declare(\"perm3_swap23\", pred= lambda s1,s2: smt.And(s1[0] == s2[0], s1[1] == s2[2], s1[2] == s2[1]))\n",
556+
"Perm3.declare(\"perm3_trans\", (\"ev1\", Perm3), (\"ev2\", Perm3), (\"y\", smt.SeqSort(smt.IntSort())), \n",
557+
" pred=lambda ev1,ev2,y,s1,s2: ev1.rel(s1,y) & ev2.rel(y,s2))\n",
558+
"Perm3 = Perm3.create() \n",
559+
"Perm3.rel.defn\n",
560+
"\n",
561+
"ev = smt.Const(\"ev\", Perm3)\n",
562+
"\n",
563+
"def seq(*args):\n",
564+
" return smt.Concat(*[smt.Unit(smt.py2expr(arg)) for arg in args])\n",
565+
"\n",
566+
"l = kd.Lemma(smt.Exists([ev], ev.rel(seq(1,2,3), seq(2,3,1))))\n",
567+
"l.exists(Perm3.perm3_trans(Perm3.perm3_swap12, Perm3.perm3_swap23, seq(2,1,3)))\n",
568+
"#l.auto(by=Perm3.rel.defn)\n",
569+
"#l.lemmas\n",
570+
"l.qed(by=[Perm3.rel.defn])\n",
571+
"\n",
572+
"# TODO: I need to implement unification to repliace their proof.\n"
573+
]
574+
},
575+
{
576+
"cell_type": "code",
577+
"execution_count": null,
578+
"metadata": {},
579+
"outputs": [
580+
{
581+
"ename": "LemmaError",
582+
"evalue": "('In by reasons:', [|- ForAll([perm3!150, seq!151, seq!152],\n perm3(perm3!150, seq!151, seq!152) ==\n If(is(perm3_swap12, perm3!150),\n And(Nth(seq!151, 0) == Nth(seq!152, 1),\n Nth(seq!151, 1) == Nth(seq!152, 0),\n Nth(seq!151, 2) == Nth(seq!152, 2)),\n If(is(perm3_swap23, perm3!150),\n And(Nth(seq!151, 0) == Nth(seq!152, 0),\n Nth(seq!151, 1) == Nth(seq!152, 2),\n Nth(seq!151, 2) == Nth(seq!152, 1)),\n If(is(perm3_trans, perm3!150),\n And(perm3(ev1(perm3!150),\n seq!151,\n y(perm3!150)),\n perm3(ev2(perm3!150),\n y(perm3!150),\n seq!152)),\n unreachable!153))))], 'is not a Proof object')",
583+
"output_type": "error",
584+
"traceback": [
585+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
586+
"\u001b[0;31mLemmaError\u001b[0m Traceback (most recent call last)",
587+
"Cell \u001b[0;32mIn[2], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m l\u001b[38;5;241m.\u001b[39mexists(Perm3\u001b[38;5;241m.\u001b[39mperm3_trans(Perm3\u001b[38;5;241m.\u001b[39mperm3_swap12, Perm3\u001b[38;5;241m.\u001b[39mperm3_swap23, seq(\u001b[38;5;241m2\u001b[39m,\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m3\u001b[39m)))\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m#l.auto(by=Perm3.rel.defn)\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[43ml\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mqed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mby\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mPerm3\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdefn\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
588+
"File \u001b[0;32m~/Documents/python/knuckledragger/kdrag/tactics.py:767\u001b[0m, in \u001b[0;36mLemma.qed\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 766\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mby\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlemmas\n\u001b[0;32m--> 767\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mkd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlemma\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mthm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
589+
"File \u001b[0;32m~/Documents/python/knuckledragger/kdrag/kernel.py:80\u001b[0m, in \u001b[0;36mlemma\u001b[0;34m(thm, by, admit, timeout, dump, solver)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m by:\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(p, __Proof):\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m LemmaError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIn by reasons:\u001b[39m\u001b[38;5;124m\"\u001b[39m, p, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis not a Proof object\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 81\u001b[0m s\u001b[38;5;241m.\u001b[39madd(p\u001b[38;5;241m.\u001b[39mthm)\n\u001b[1;32m 82\u001b[0m s\u001b[38;5;241m.\u001b[39madd(smt\u001b[38;5;241m.\u001b[39mNot(thm))\n",
590+
"\u001b[0;31mLemmaError\u001b[0m: ('In by reasons:', [|- ForAll([perm3!150, seq!151, seq!152],\n perm3(perm3!150, seq!151, seq!152) ==\n If(is(perm3_swap12, perm3!150),\n And(Nth(seq!151, 0) == Nth(seq!152, 1),\n Nth(seq!151, 1) == Nth(seq!152, 0),\n Nth(seq!151, 2) == Nth(seq!152, 2)),\n If(is(perm3_swap23, perm3!150),\n And(Nth(seq!151, 0) == Nth(seq!152, 0),\n Nth(seq!151, 1) == Nth(seq!152, 2),\n Nth(seq!151, 2) == Nth(seq!152, 1)),\n If(is(perm3_trans, perm3!150),\n And(perm3(ev1(perm3!150),\n seq!151,\n y(perm3!150)),\n perm3(ev2(perm3!150),\n y(perm3!150),\n seq!152)),\n unreachable!153))))], 'is not a Proof object')"
591+
]
592+
}
593+
],
594+
"source": []
595+
},
596+
{
597+
"cell_type": "code",
598+
"execution_count": null,
599+
"metadata": {},
600+
"outputs": [],
601+
"source": [
602+
"#l.qed(by=Perm3.rel.defn)\n",
603+
"\"\"\"\n",
604+
"l.unfold(Perm3.rel)\n",
605+
"l.simp()\n",
606+
"l.split()\n",
607+
"\n",
608+
"l.unfold(Perm3.rel)\n",
609+
"l.simp()\n",
610+
"l.auto()\n",
611+
"\n",
612+
"l.auto(by=Perm3.rel.defn)\n",
613+
"l.qed()\n",
614+
"\"\"\"\n",
615+
"\n",
616+
"#l.auto(by=Perm3.rel.defn)"
617+
]
618+
},
619+
{
620+
"cell_type": "markdown",
621+
"metadata": {},
622+
"source": [
623+
"\n",
624+
"\n",
625+
"\n",
626+
"\n",
627+
"\n",
457628
"(** ** Example: Evenness (yet again) *)\n",
458629
"\n",
459630
"(** We've already seen two ways of stating a proposition that a number\n",

kdrag/notation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ def define(self, args, body):
116116
div = SortDispatch(name="div_")
117117
smt.ExprRef.__truediv__ = lambda x, y: div(x, y) # type: ignore
118118

119-
and_ = SortDispatch()
119+
and_ = SortDispatch(name="and_")
120120
smt.ExprRef.__and__ = lambda x, y: and_(x, y) # type: ignore
121121

122-
or_ = SortDispatch()
122+
or_ = SortDispatch(name="or_")
123123
smt.ExprRef.__or__ = lambda x, y: or_(x, y) # type: ignore
124124

125125
invert = SortDispatch(name="invert")
@@ -528,7 +528,7 @@ def create_relation(dt):
528528

529529
def create():
530530
dt = oldcreate()
531-
dtrel = smt.Function(name, dt, *param_sorts, smt.BoolSort())
531+
dtrel = smt.Function(relname, dt, *param_sorts, smt.BoolSort())
532532
rel.register(
533533
dt, lambda *args: dtrel(*args)
534534
) # doing this here let's us tie the knot inside of lambdas and refer to the predicate.

kdrag/tactics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,9 @@ def qed(self, **kwargs) -> kd.kernel.Proof:
757757
"""
758758
return the actual final `Proof` of the lemma that was defined at the beginning.
759759
"""
760+
760761
if "by" in kwargs:
761-
kwargs["by"] += self.lemmas
762+
kwargs["by"].extend(self.lemmas)
762763
else:
763764
kwargs["by"] = self.lemmas
764765
return kd.kernel.lemma(self.thm, **kwargs)

kdrag/theories/datatypes/list.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,64 @@
11
import kdrag as kd
22
import functools
3+
import kdrag.smt as smt
34

45

56
@functools.cache
6-
def List(sort):
7-
dt = kd.Inductive(f"List<{sort.name()}>")
8-
dt.add_constructor("Nil")
9-
dt.add_constructor("Cons", ("cons", sort), ("tail", dt))
7+
def List(sort: smt.SortRef):
8+
"""
9+
Build List sort
10+
>>> IntList = List(smt.IntSort())
11+
>>> IntList.Cons(1, IntList.Nil)
12+
Cons(1, Nil)
13+
"""
14+
dt = kd.Inductive("List_" + sort.name())
15+
dt.declare("Nil")
16+
dt.declare("Cons", ("head", sort), ("tail", dt))
1017
return dt.create()
18+
19+
20+
def list(*args):
21+
"""
22+
Helper to construct List values
23+
>>> list(1, 2, 3)
24+
Cons(1, Cons(2, Cons(3, Nil)))
25+
"""
26+
if len(args) == 0:
27+
raise ValueError("list() requires at least one argument")
28+
LT = List(smt._py2expr(args[0]).sort())
29+
acc = LT.Nil
30+
for a in reversed(args):
31+
acc = LT.Cons(a, acc)
32+
return acc
33+
34+
35+
def Cons(x, xs):
36+
"""
37+
Helper to construct Cons values
38+
>>> Cons(1, Nil(smt.IntSort()))
39+
Cons(1, Nil)
40+
"""
41+
LT = List(smt._py2expr(x).sort())
42+
return LT.Cons(x, xs)
43+
44+
45+
def Nil(sort: smt.SortRef) -> smt.DatatypeRef:
46+
"""
47+
Helper to construct Nil values
48+
>>> Nil(smt.IntSort())
49+
Nil
50+
"""
51+
return List(sort).Nil
52+
53+
54+
def Unit(x: smt.ExprRef) -> smt.DatatypeRef:
55+
"""
56+
Helper to create Unit values
57+
>>> Unit(42)
58+
Cons(42, Nil)
59+
>>> Unit(42).sort()
60+
List_Int
61+
"""
62+
x = smt._py2expr(x)
63+
LT = List(x.sort())
64+
return LT.Cons(x, LT.Nil)

kdrag/theories/datatypes/option.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import functools
2+
import kdrag.smt as smt
3+
import kdrag as kd
4+
5+
6+
@functools.cache
7+
def Option(T: smt.SortRef, admit=False) -> smt.DatatypeRef:
8+
"""
9+
Define an Option type for a given type T
10+
>>> OInt = Option(smt.IntSort())
11+
>>> OInt.Some(1)
12+
Some(1)
13+
>>> OInt.None_
14+
None_
15+
>>> OInt.Some(1).val
16+
val(Some(1))
17+
"""
18+
Option = kd.Inductive("Option_" + T.name(), admit=admit)
19+
Option.declare("None_")
20+
Option.declare("Some", ("val", T))
21+
Option = Option.create()
22+
return Option
23+
24+
25+
# This should also perhaps be a SortDispatch
26+
def get(x: smt.DatatypeRef, default: smt.ExprRef) -> smt.ExprRef:
27+
"""
28+
Get the value of an Option, or a default value if it is None_
29+
>>> get(Some(42), 0)
30+
If(is(Some, Some(42)), val(Some(42)), 0)
31+
"""
32+
return smt.If(x.is_Some, x.val, default)
33+
34+
35+
# I guess I could make this a SortDispatch for regularity. I just don't see why I'd need to overload in any way but the default
36+
def Some(x: smt.ExprRef) -> smt.DatatypeRef:
37+
"""
38+
Helper to create Option values
39+
>>> Some(42)
40+
Some(42)
41+
>>> Some(42).sort()
42+
Option_Int
43+
"""
44+
x = smt._py2expr(x)
45+
return Option(x.sort()).Some(x)
46+
47+
48+
def None_(T: smt.SortRef) -> smt.DatatypeRef:
49+
"""
50+
Helper to create Option None_ values
51+
>>> None_(smt.IntSort())
52+
None_
53+
"""
54+
return Option(T).None_

kdrag/theories/seq.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ def induct(T: smt.SortRef, P) -> kd.kernel.Proof:
2020
)
2121

2222

23+
def seq(*args):
24+
"""
25+
Helper to construct sequences.
26+
>>> seq(1, 2, 3)
27+
Concat(Unit(1), Concat(Unit(2), Unit(3)))
28+
"""
29+
return smt.Concat(*[smt.Unit(smt._py2expr(a)) for a in args])
30+
31+
2332
class Seq:
2433
def __init__(self, T):
2534
self.T = T

tests/test_notebooks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import nbformat
33
from nbclient import NotebookClient
4+
import subprocess
45

56

67
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)