Skip to content

Commit b5dfa18

Browse files
committed
[elpi] link DK metas with Elpi UVars
1 parent 519aa92 commit b5dfa18

File tree

3 files changed

+146
-26
lines changed

3 files changed

+146
-26
lines changed

src/core/elpi_lambdapi.ml

Lines changed: 117 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
11
open Elpi.API
22

3+
module Elpi_AUX = struct
4+
let array_map_fold f st a =
5+
let len = Array.length a in
6+
let st = ref st in
7+
let b = Array.make len RawData.mkNil in
8+
for i = 0 to len-1 do
9+
let st', x = f !st a.(i) in
10+
st := st';
11+
b.(i) <- x
12+
done;
13+
!st, b
14+
15+
let list_map_fold f s l =
16+
let f st x = let st, x = f st x in st, x, [] in
17+
let s, l, _ = Utils.map_acc f s l in
18+
s, l
19+
20+
end
21+
322
let sym : Terms.sym Conversion.t = OpaqueData.declare {
423
OpaqueData.name = "symbol";
524
doc = "A symbol";
@@ -17,12 +36,23 @@ let prodc = RawData.Constants.declare_global_symbol "prod"
1736
let abstc = RawData.Constants.declare_global_symbol "abst"
1837
let applc = RawData.Constants.declare_global_symbol "appl"
1938

39+
module M = struct
40+
type t = Terms.meta
41+
let compare m1 m2 = Stdlib.compare m1.Terms.meta_key m2.Terms.meta_key
42+
let pp = Print.pp_meta
43+
let show m = Format.asprintf "%a" pp m
44+
end
45+
module MM = FlexibleData.Map(M)
46+
47+
let metamap : MM.t State.component = MM.uvmap
48+
2049
let embed_term : Terms.term Conversion.embedding = fun ~depth st t ->
2150
let open RawData in
2251
let open Terms in
2352
let gls = ref [] in
2453
let call f ~depth s x = let s, x, g = f ~depth s x in gls := g @ !gls; s, x in
25-
let rec aux ~depth st ctx = function
54+
let rec aux ~depth ctx st t =
55+
match Terms.unfold t with
2656
| Vari v ->
2757
let d = Ctxt.type_of v ctx in
2858
st, mkBound d
@@ -32,70 +62,92 @@ let embed_term : Terms.term Conversion.embedding = fun ~depth st t ->
3262
let st, s = call sym.Conversion.embed ~depth st s in
3363
st, mkApp symbc s []
3464
| Prod (src, tgt) ->
35-
let st, src = aux ~depth st ctx src in
65+
let st, src = aux ~depth ctx st src in
3666
let _,tgt,ctx = Ctxt.unbind ctx depth None tgt in
37-
let st, tgt = aux ~depth:(depth+1) st ctx tgt in
67+
let st, tgt = aux ~depth:(depth+1) ctx st tgt in
3868
st, mkApp prodc src [mkLam tgt]
3969
| Abst (ty, body) ->
40-
let st, ty = aux ~depth st ctx ty in
70+
let st, ty = aux ~depth ctx st ty in
4171
let _,body,ctx = Ctxt.unbind ctx depth None body in
42-
let st, body = aux ~depth:(depth+1) st ctx body in
72+
let st, body = aux ~depth:(depth+1) ctx st body in
4373
st, mkApp prodc ty [mkLam body]
4474
| Appl (hd, arg) ->
45-
let st, hd = aux ~depth st ctx hd in
46-
let st, arg = aux ~depth st ctx arg in
75+
let st, hd = aux ~depth ctx st hd in
76+
let st, arg = aux ~depth ctx st arg in
4777
st, mkApp applc hd [arg]
48-
| Meta _ -> assert false
78+
| Meta (meta,args) ->
79+
let st, flex =
80+
try st, MM.elpi meta (State.get metamap st)
81+
with Not_found ->
82+
let st, flex = FlexibleData.Elpi.make st in
83+
State.update metamap st (MM.add flex meta), flex in
84+
let st, args = Elpi_AUX.array_map_fold (aux ~depth ctx) st args in
85+
st, mkUnifVar flex ~args:(Array.to_list args) st
4986
| Patt _ -> Console.fatal_no_pos "embed_term: Patt not implemented"
5087
| TEnv _ -> Console.fatal_no_pos "embed_term: TEnv not implemented"
5188
| Wild -> Console.fatal_no_pos "embed_term: Wild not implemented"
5289
| TRef _ -> Console.fatal_no_pos "embed_term: TRef not implemented"
5390
| LLet _ -> Console.fatal_no_pos "embed_term: LLet not implemented"
5491
in
55-
let st, t = aux ~depth st [] t in
92+
let st, t = aux ~depth [] st t in
5693
st, t, List.rev !gls
5794

58-
let readback_term : Terms.term Conversion.readback = fun ~depth st t ->
95+
let readback_term_box : Terms.term Bindlib.box Conversion.readback = fun ~depth st t ->
5996
let open RawData in
6097
let open Terms in
6198
let gls = ref [] in
6299
let call f ~depth s x = let s, x, g = f ~depth s x in gls := g @ !gls; s, x in
63-
let rec aux ~depth st ctx t =
100+
let rec aux ~depth ctx st t =
64101
match look ~depth t with
65102
| Const c when c == typec -> st, _Type
66103
| Const c when c == kindc -> st, _Kind
67104
| Const c when c >= 0 ->
68105
begin try
69106
let v = Extra.IntMap.find c ctx in
70107
st, _Vari v
71-
with Not_found -> Utils.type_error "readback_term" end
108+
with Not_found -> Utils.type_error "readback_term: free variable" end
72109
| App(c,s,[]) when c == symbc ->
73110
let st, s = call sym.Conversion.readback ~depth st s in
74111
st, _Symb s
75112
| App(c,ty,[bo]) when c == prodc ->
76-
let st, ty = aux ~depth st ctx ty in
77-
let st, bo = aux_lam ~depth st ctx bo in
113+
let st, ty = aux ~depth ctx st ty in
114+
let st, bo = aux_lam ~depth ctx st bo in
78115
st, _Prod ty bo
79116
| App(c,ty,[bo]) when c == abstc ->
80-
let st, ty = aux ~depth st ctx ty in
81-
let st, bo = aux_lam ~depth st ctx bo in
117+
let st, ty = aux ~depth ctx st ty in
118+
let st, bo = aux_lam ~depth ctx st bo in
82119
st, _Abst ty bo
83120
| App(c,hd,[arg]) when c == applc ->
84-
let st, hd = aux ~depth st ctx hd in
85-
let st, arg = aux ~depth st ctx arg in
121+
let st, hd = aux ~depth ctx st hd in
122+
let st, arg = aux ~depth ctx st arg in
86123
st, _Appl hd arg
124+
| UnifVar(flex, args) ->
125+
let st, meta =
126+
try st, MM.host flex (State.get metamap st)
127+
with Not_found ->
128+
let m1 = fresh_meta (Env.to_prod Env.empty _Type) 0 in
129+
let a = Env.to_prod Env.empty (_Meta m1 [||]) in
130+
let m2 = fresh_meta a 0 in
131+
State.update metamap st (MM.add flex m2), m2
132+
in
133+
let st, args = Elpi_AUX.list_map_fold (aux ~depth ctx) st args in
134+
st, _Meta meta (Array.of_list args)
87135
| _ -> Utils.type_error "readback_term"
88-
and aux_lam ~depth st ctx t =
136+
and aux_lam ~depth ctx st t =
89137
match look ~depth t with
90138
| Lam bo ->
91139
let v = Bindlib.new_var mkfree "x" in
92140
let ctx = Extra.IntMap.add depth v ctx in
93-
let st, bo = aux ~depth:(depth+1) st ctx bo in
141+
let st, bo = aux ~depth:(depth+1) ctx st bo in
94142
st, Bindlib.bind_var v bo
95143
| _ -> Utils.type_error "readback_term"
96144
in
97-
let st, t = aux ~depth st Extra.IntMap.empty t in
98-
st, Bindlib.unbox t, List.rev !gls
145+
let st, t = aux ~depth Extra.IntMap.empty st t in
146+
st, t, List.rev !gls
147+
148+
let readback_term ~depth st t =
149+
let st, t, gls = readback_term_box ~depth st t in
150+
st, Bindlib.unbox t, gls
99151

100152
let term : Terms.term Conversion.t = {
101153
Conversion.ty = Conversion.TyName "term";
@@ -113,6 +165,35 @@ type prod term -> (term -> term) -> term.
113165
embed = embed_term;
114166
}
115167

168+
let readback_mbinder st t =
169+
let open RawData in
170+
let rec aux ~depth nvars t =
171+
match look ~depth t with
172+
| Lam bo -> aux ~depth:(depth+1) (nvars+1) bo
173+
| _ ->
174+
let open Bindlib in
175+
let vs = Array.init nvars (fun i -> new_var Terms.mkfree (Printf.sprintf "x%d" i)) in
176+
let st, t, _ = readback_term_box ~depth st t in
177+
st, unbox (bind_mvar vs t)
178+
in
179+
aux ~depth:0 0 t
180+
181+
182+
let readback_assignments st =
183+
let mmap = State.get metamap st in
184+
MM.fold (fun meta _flex body st ->
185+
match body with
186+
| None -> st
187+
| Some t ->
188+
let open Timed in
189+
match ! (meta.Terms.meta_value) with
190+
| Some _ -> assert false
191+
| None ->
192+
let st, t = readback_mbinder st t in
193+
meta.Terms.meta_value := Some t;
194+
st
195+
) mmap st
196+
116197
let lambdapi_builtin_declarations : BuiltIn.declaration list =
117198
let open BuiltIn in
118199
let open BuiltInPredicate in
@@ -126,6 +207,13 @@ let lambdapi_builtin_declarations : BuiltIn.declaration list =
126207

127208
LPDoc "---- Builtin predicates ----";
128209

210+
MLCode(Pred("lp.sig",
211+
In(sym,"S",
212+
Out(term,"T",
213+
Easy "Gives the type of a symbol")),
214+
(fun s _ ~depth:_ -> !: (Timed.(!) s.Terms.sym_type))),
215+
DocAbove);
216+
129217
MLCode(Pred("lp.term->string",
130218
In(term,"T",
131219
Out(string,"S",
@@ -175,7 +263,13 @@ fun ss file predicate arg ->
175263
if not (Elpi.API.Compile.static_check ~checker:(Elpi.Builtin.default_checker ()) query) then
176264
Console.fatal pos "elpi: type error";
177265
let exe = Elpi.API.Compile.optimize query in
266+
Format.printf "\nelpi: before: %a\n" Print.pp_term arg;
178267
match Execute.once exe with
179-
| Execute.Success _ -> ()
268+
| Execute.Success { Data.state; pp_ctx; constraints; _ } ->
269+
let _ = readback_assignments state in
270+
Format.printf "\nelpi: after: %a\n"
271+
Print.pp_term arg;
272+
Format.printf "elpi: constraints: %a\n"
273+
Pp.(constraints pp_ctx) constraints
180274
| Failure -> Console.fatal_no_pos "elpi: failure"
181275
| NoMoreSteps -> assert false

tests/OK/elpitest.elpi

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,29 @@
11
main T :-
2-
print T,
3-
print {lp.term->string T}.
2+
print "before type inference" T,
3+
of T _,
4+
print "\nafter type inference" T.
5+
6+
7+
pred of i:term, o:term.
8+
9+
% API to access the type of a symbol
10+
of (symb S) T :- lp.sig S T.
11+
12+
% silly rules
13+
of (prod A B) T :-
14+
of A typ,
15+
pi x\ of x A => of (B x) T.
16+
17+
of (appl H A) Ta :-
18+
of H (prod S T),
19+
of A S,
20+
Ta = T A.
21+
22+
% suspension of typing on holes (type constraint)
23+
of (uvar as U) T :-
24+
declare_constraint (of U T) [U].
25+
26+
% uniqueness of typing
27+
constraint of {
28+
rule (of (uvar X _) T) \ (of (uvar X _) S) <=> (S = T).
29+
}

tests/OK/elpitest.lp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ require open tests.OK.logic
22
require open tests.OK.bool
33
require open tests.OK.nat
44

5-
elpi "tests/OK/elpitest.elpi" "main"x y : nat, P (eq nat x y) → P (eq nat y x))
5+
elpi "tests/OK/elpitest.elpi" "main"x y, P (eq ?T[x;y] x y) → P (eq ?T[x;y] y x))

0 commit comments

Comments
 (0)