Skip to content

feat: add ToCBOR deriving #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions KLR/Serde/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ instance : FromCBOR Bool where
if roundtrip false == false then
throwError "CBOR Bool mismatch: false"

-- ... or we could also just prove the theorem
-- With some automation, may be able to prove all these theorems
theorem rt_bool (b : Bool) :
.ok b = fromCBOR (toCBOR b) := by
unfold toCBOR fromCBOR
unfold parse
unfold instToCBORBool instFromCBORBool
unfold ByteArray.size Except.map Prod.snd
unfold getElem
unfold ByteArray.instGetElemNatUInt8LtSize
unfold ByteArray.get
unfold pure Applicative.toPure Monad.toApplicative
unfold Except.instMonad
unfold Except.pure
induction b <;> simp
done

/-
# 8-bit integers
-/
Expand Down Expand Up @@ -596,3 +613,39 @@ warning: declaration uses 'sorry'
#guard_msgs in
example (x : List (Bool × List UInt8)) :
roundtrip x == true := by plausible

/-
# Encoding of arbitrary inductive types

Each inductive type (including structures) have a serde tag assigned to them,
and each constructor also has a tag. We use the CBOR "tagged value" encoding
for this. For each type, we use a 16-bit tag, the first half is the type tag
and the second half is the constructor tag. Following the tag we have a tuple
of the constructor arguments.

The functions below are called by the derived instances to build and parse the
tagged encoding. For simplicity we limit constructors to at most 24 arguments;
this constraint can be lifted by modifying these functions.
TODO: lift this constraint.
-/

def cborTag (typeTag valTag len : Nat) : ByteArray :=
assert! len < 0x18
let len := toCBOR len.toUInt64
let len := adjustTag 0x80 len
.mk #[ 0xd9, typeTag.toUInt8, valTag.toUInt8] ++ len

def parseCBORTag (arr : ByteArray) : Err (Nat × Nat × Nat) := do
if h:arr.size > 4 then
if arr[0] != 0xd9 then
throw "expecting tagged value"
let typeTag := arr[1]
let valTag := arr[2]
let listTag := arr[3] >>> 5
let listSize := arr[3] &&& 0x1f
if listTag != 0x80 then
throw "expecting list after tagged value"
if listSize >= 0x18 then
throw "expecting small list after tagged value"
return (typeTag.toNat, valTag.toNat, listSize.toNat)
throw "expecting tagged value - array too small"
106 changes: 105 additions & 1 deletion KLR/Serde/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,115 @@ Authors: Paul Govereau, Sean McLaughlin
-/
import KLR.Serde.Attr
import KLR.Serde.Basic
import Lean

/-
# ToCBOR and FromCBOR deriving
-/

namespace KLR.Serde
namespace KLR.Serde.Elab
open Lean Parser.Term Meta Elab Command Deriving

-- Generate a absolute path for a name
def rootName : Name -> Name
| .anonymous => .str .anonymous "_root_"
| .str n s => .str (rootName n) s
| .num n i => .num (rootName n) i

-- Remove KLR prefix from a name
def rmKLR : Name -> Name
| .anonymous => .anonymous
| .str n "KLR" => n
| .str n s => .str (rmKLR n) s
| .num n i => .num (rmKLR n) i

-- Create a fully qualified identifier (e.g. _root_.KLR.foo)
def qualIdent (n : Name) (s : String) : Ident :=
mkIdent (.str (rootName n) s)

-- Generate a (function) name in the namespace of `name`
def fnIdent (name : Name) (s : String) : Ident :=
mkIdent (.str name s)

-- Generate a name suitable for an extern (C) symbol
def cIdent (name : Name) (s : String) : Ident :=
let name := rmKLR name
let cname := name.toString.replace "." "_" ++ "_" ++ s
mkIdent cname.toName

-- Make a list of parameter names, e.g.: x0, x1, ...
def makeNames (n : Nat) (s : String) : Array (TSyntax `ident) :=
let names := (List.range n).map fun n => Name.mkStr1 s!"{s}{n}"
let ids := names.map mkIdent
ids.toArray

-- Get constructor parameter names, e.g: C : X -> Y -> Z ===> x0, x1
def getParams (ctor : Name) : TermElabM (Array (TSyntax `ident)) := do
let ci <- getConstInfoCtor ctor
-- skip over implicit arguments
let count <- forallTelescopeReducing ci.type fun xs _ => do
let bis <- xs.mapM fun x => x.fvarId!.getBinderInfo
let bis := bis.filter fun bi => bi.isExplicit
pure bis.size
return makeNames count "x"

-- Get type parameter names, e.g: T a b ===> a0, a1
def getTypeParams (name : Name) : TermElabM (Array (TSyntax `ident)) := do
let tci <- getConstInfoInduct name
return makeNames tci.numParams "a"

private def lit := Syntax.mkNatLit

-- Generate ToCBOR instances for a set of mutually recursive types
def mkToInstances (names : Array Name) : TermElabM (Array Command) := do
let tags <- liftMetaM <| names.mapM serdeTags
let mut cmds := #[]
let mut insts := #[]
for (name, typeTag, constTags) in names.zip tags do
-- Generate match arms for each constructor
let mut arms := #[]
for (c, val) in constTags do
let ps <- getParams c
let arm <- `(matchAltExpr| | $(mkIdent c) $ps* => Id.run do
let arr := cborTag $(lit typeTag) $(lit val) $(lit ps.size)
$[let arr := arr ++ toCBOR $ps]*
pure arr)
arms := arms.push arm

-- Generate local instances for all type constructors
-- They don'thave to be named, only in scope
let ts <- getTypeParams name
let ls := names.foldrM fun n body => `(
let _ : ToCBOR ($(mkIdent n) $ts*) := ⟨$(fnIdent n "toBytes")⟩
$body
)
-- combine local instances with body
let body <- ls (<- `(match x with $arms:matchAlt*))

-- Generate function for current type constructor
let bs <- ts.mapM fun t => `(instBinder| [ToCBOR $t])
let tname <- `( $(mkIdent name) $ts*)
let cmd <- `(
@[export $(cIdent name "toBytes")]
partial def $(qualIdent name "toBytes") $bs:instBinder* (x : $tname) : ByteArray :=
$body:term
)
cmds := cmds.push cmd

-- Generate public instance declaration for current type constructor
let inst <- `(
instance $bs:instBinder* : ToCBOR $tname := ⟨ $(fnIdent name "toBytes") ⟩
)
insts := insts.push inst

-- combine all functions into a mutual block
-- return mutual block followed by public instance declarations
return #[<- `(mutual $cmds* end)] ++ insts

def mkToCBOR (names : Array Name) : CommandElabM Bool := do
let cmds <- liftTermElabM (mkToInstances names)
cmds.forM elabCommand
return true

initialize
registerDerivingHandler ``ToCBOR mkToCBOR
34 changes: 33 additions & 1 deletion KLR/Serde/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau, Sean McLaughlin
-/
import KLR.Serde.Attr
import Lean
import KLR.Serde.Elab

/-
Tests for the serde attribute.
Expand Down Expand Up @@ -116,3 +116,35 @@ first argument of `lean_alloc_ctor` should be for the `Lean.Name` type.
for (n,v) in <- serdeMap ``Lean.Name do
let n := n.toString.replace "." "_"
IO.println s!"#define {n} {v}"

/-
Deriving tests
-/

@[serde tag=7]
inductive Z where
| a : Nat -> Z
| b : Bool -> Z
deriving ToCBOR

#guard (toCBOR (Z.a 0)).data == #[0xd9, 7, 0, 0x81, 0]
#guard (toCBOR (Z.b true)).data == #[0xd9, 7, 1, 0x81, 0xf5]

mutual
@[serde tag=1]
structure X (a : Type u) where
i : a
b : Bool
deriving ToCBOR

@[serde tag=2]
inductive Y (a : Type u) where
| n : Nat -> Y a
| x : X a -> Y a
deriving ToCBOR
end

#guard (toCBOR (X.mk true false)).data == #[0xd9, 1, 0, 0x82, 0xf5, 0xf4]
#guard (toCBOR (Y.n 7 : Y Bool)).data == #[0xd9, 2, 0, 0x81, 7]
#guard (toCBOR (Y.x (X.mk true false))).data ==
#[0xd9, 2, 1, 0x81, 0xd9, 1, 0, 0x82, 0xf5, 0xf4]