Skip to content

Commit 9eaff01

Browse files
authored
Improve code position (match + extended assignments) (#654)
- `^match` pattern now allowed within a code pos - `#cname` can be used to select the appropriate sub-branch of a match, e.g. `^match#Some.1` - `^lv<@` and `^lv<$` are now permitted
1 parent ed8f813 commit 9eaff01

File tree

8 files changed

+144
-33
lines changed

8 files changed

+144
-33
lines changed

src/ecMatching.ml

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ module Position = struct
1919
| `If
2020
| `While
2121
| `Assign of lvmatch
22-
| `Sample
23-
| `Call
22+
| `Sample of lvmatch
23+
| `Call of lvmatch
24+
| `Match
2425
]
2526

2627
and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ]
@@ -30,9 +31,10 @@ module Position = struct
3031
| `ByMatch of int option * cp_match
3132
]
3233

33-
type codepos1 = int * cp_base
34-
type codepos = (codepos1 * int) list * codepos1
35-
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
34+
type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol]
35+
type codepos1 = int * cp_base
36+
type codepos = (codepos1 * codepos_brsel) list * codepos1
37+
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
3638

3739
let shift1 ~(offset : int) ((o, p) : codepos1) : codepos1 =
3840
(o + offset, p)
@@ -57,12 +59,19 @@ module Zipper = struct
5759
type ('a, 'state) folder =
5860
'a -> 'state -> instr -> 'state * instr list
5961

62+
type spath_match_ctxt = {
63+
locals : (EcIdent.t * ty) list;
64+
prebr : ((EcIdent.t * ty) list * stmt) list;
65+
postbr : ((EcIdent.t * ty) list * stmt) list;
66+
}
67+
6068
type ipath =
6169
| ZTop
6270
| ZWhile of expr * spath
6371
| ZIfThen of expr * spath * stmt
6472
| ZIfElse of expr * stmt * spath
65-
73+
| ZMatch of expr * spath * spath_match_ctxt
74+
6675
and spath = (instr list * instr list) * ipath
6776

6877
type zipper = {
@@ -95,9 +104,12 @@ module Zipper = struct
95104
match ir.i_node, cm with
96105
| Swhile _, `While -> i-1
97106
| Sif _, `If -> i-1
98-
| Srnd _, `Sample -> i-1
99-
| Scall _, `Call -> i-1
107+
| Smatch _, `Match -> i-1
108+
109+
| Scall (None, _, _), `Call `LvmNone -> i-1
100110

111+
| Scall (Some lv, _, _), `Call lvm
112+
| Srnd (lv, _), `Sample lvm
101113
| Sasgn (lv, _), `Assign lvm -> begin
102114
match lv, lvm with
103115
| _, `LvmNone -> i-1
@@ -178,23 +190,34 @@ module Zipper = struct
178190

179191
let zipper_at_nm_cpos1
180192
(env : EcEnv.env)
181-
((cp1, sub) : codepos1 * int)
193+
((cp1, sub) : codepos1 * codepos_brsel)
182194
(s : stmt)
183195
(zpr : ipath)
184-
: (ipath * stmt) * (codepos1 * int)
196+
: (ipath * stmt) * (codepos1 * codepos_brsel)
185197
=
186198
let (s1, i, s2) = find_by_cpos1 env cp1 s in
187199
let zpr =
188200
match i.i_node, sub with
189-
| Swhile (e, sw), 0 ->
201+
| Swhile (e, sw), `Cond true ->
190202
(ZWhile (e, ((s1, s2), zpr)), sw)
191203

192-
| Sif (e, ifs1, ifs2), 0 ->
204+
| Sif (e, ifs1, ifs2), `Cond true ->
193205
(ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1)
194206

195-
| Sif (e, ifs1, ifs2), 1 ->
207+
| Sif (e, ifs1, ifs2), `Cond false ->
196208
(ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2)
197209

210+
| Smatch (e, bs), `Match cn ->
211+
let _, indt, _ = oget (EcEnv.Ty.get_top_decl e.e_ty env) in
212+
let indt = oget (EcDecl.tydecl_as_datatype indt) in
213+
let cnames = List.fst indt.tydt_ctors in
214+
let ix, _ =
215+
try List.findi (fun _ n -> EcSymbols.sym_equal cn n) cnames
216+
with Not_found -> raise InvalidCPos
217+
in
218+
let prebr, (locals, body), postbr = List.pivot_at ix bs in
219+
(ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body)
220+
198221
| _ -> raise InvalidCPos
199222
in zpr, ((0, `ByPos (1 + List.length s1)), sub)
200223

@@ -228,6 +251,8 @@ module Zipper = struct
228251
| ZWhile (e, sp) -> zip (Some (i_while (e, s))) sp
229252
| ZIfThen (e, sp, se) -> zip (Some (i_if (e, s, se))) sp
230253
| ZIfElse (e, se, sp) -> zip (Some (i_if (e, se, s))) sp
254+
| ZMatch (e, sp, mpi) ->
255+
zip (Some (i_match (e, mpi.prebr @ (mpi.locals, s) :: mpi.postbr))) sp
231256

232257
let zip zpr = zip None ((zpr.z_head, zpr.z_tail), zpr.z_path)
233258

@@ -238,6 +263,7 @@ module Zipper = struct
238263
| ZWhile (_, ((_, is), ip)) -> doit (is :: acc) ip
239264
| ZIfThen (_, ((_, is), ip), _) -> doit (is :: acc) ip
240265
| ZIfElse (_, _, ((_, is), ip)) -> doit (is :: acc) ip
266+
| ZMatch (_, ((_, is), ip), _) -> doit (is :: acc) ip
241267
in
242268

243269
let after =
@@ -1298,6 +1324,10 @@ module RegexpBaseInstr = struct
12981324
let z' = zipper head tail path in
12991325
next_zipper z'
13001326

1327+
| ZMatch (_, ((head, tail), path), _) ->
1328+
let z' = zipper head tail path in
1329+
next_zipper z'
1330+
13011331
let next (e : engine) =
13021332
next_zipper e.e_zipper |> omap (fun z ->
13031333
{ e with e_zipper = z; e_pos = List.length z.z_head })

src/ecMatching.mli

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ module Position : sig
1414
type cp_match = [
1515
| `If
1616
| `While
17+
| `Match
1718
| `Assign of lvmatch
18-
| `Sample
19-
| `Call
19+
| `Sample of lvmatch
20+
| `Call of lvmatch
2021
]
2122

2223
and lvmatch = [ `LvmNone | `LvmVar of EcTypes.prog_var ]
@@ -26,9 +27,10 @@ module Position : sig
2627
| `ByMatch of int option * cp_match
2728
]
2829

29-
type codepos1 = int * cp_base
30-
type codepos = (codepos1 * int) list * codepos1
31-
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
30+
type codepos_brsel = [`Cond of bool | `Match of EcSymbols.symbol]
31+
type codepos1 = int * cp_base
32+
type codepos = (codepos1 * codepos_brsel) list * codepos1
33+
type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1]
3234

3335
val shift1 : offset:int -> codepos1 -> codepos1
3436
val shift : offset:int -> codepos -> codepos
@@ -40,11 +42,18 @@ end
4042
module Zipper : sig
4143
open Position
4244

45+
type spath_match_ctxt = {
46+
locals : (EcIdent.t * ty) list;
47+
prebr : ((EcIdent.t * ty) list * stmt) list;
48+
postbr : ((EcIdent.t * ty) list * stmt) list;
49+
}
50+
4351
type ipath =
4452
| ZTop
4553
| ZWhile of expr * spath
4654
| ZIfThen of expr * spath * stmt
4755
| ZIfElse of expr * stmt * spath
56+
| ZMatch of expr * spath * spath_match_ctxt
4857

4958
and spath = (instr list * instr list) * ipath
5059

src/ecParser.mly

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2606,9 +2606,10 @@ tac_dir:
26062606
icodepos_r:
26072607
| IF { (`If :> pcp_match) }
26082608
| WHILE { (`While :> pcp_match) }
2609-
| LESAMPLE { (`Sample :> pcp_match) }
2610-
| LEAT { (`Call :> pcp_match) }
2609+
| MATCH { (`Match :> pcp_match) }
26112610

2611+
| lvm=lvmatch LESAMPLE { (`Sample lvm :> pcp_match) }
2612+
| lvm=lvmatch LEAT { (`Call lvm :> pcp_match) }
26122613
| lvm=lvmatch LARROW { (`Assign lvm :> pcp_match) }
26132614

26142615
lvmatch:
@@ -2631,9 +2632,14 @@ codepos1:
26312632
| cp=codepos1_wo_off AMP PLUS i=word { ( i, cp) }
26322633
| cp=codepos1_wo_off AMP MINUS i=word { (-i, cp) }
26332634

2635+
branch_select:
2636+
| SHARP s=boident DOT {`Match s}
2637+
| DOT { `Cond true }
2638+
| QUESTION { `Cond false }
2639+
26342640
%inline nm1_codepos:
2635-
| i=codepos1 k=ID(DOT { 0 } | QUESTION { 1 } )
2636-
{ (i, k) }
2641+
| i=codepos1 bs=branch_select
2642+
{ (i, bs) }
26372643

26382644
codepos:
26392645
| nm=rlist0(nm1_codepos, empty) i=codepos1

src/ecParsetree.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -490,17 +490,19 @@ type preduction = {
490490
type pcp_match = [
491491
| `If
492492
| `While
493+
| `Match
493494
| `Assign of plvmatch
494-
| `Sample
495-
| `Call
495+
| `Sample of plvmatch
496+
| `Call of plvmatch
496497
]
497498

498499
and plvmatch = [ `LvmNone | `LvmVar of pqsymbol ]
499500

500501
type pcp_base = [ `ByPos of int | `ByMatch of int option * pcp_match ]
501502

503+
type pbranch_select = [`Cond of bool | `Match of psymbol]
502504
type pcodepos1 = int * pcp_base
503-
type pcodepos = (pcodepos1 * int) list * pcodepos1
505+
type pcodepos = (pcodepos1 * pbranch_select) list * pcodepos1
504506
type pdocodepos1 = pcodepos1 doption option
505507

506508
type pcodeoffset1 = [

src/ecPrinting.ml

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,9 +2124,12 @@ let pp_codepos1 (ppe : PPEnv.t) (fmt : Format.formatter) ((off, cp) : CP.codepos
21242124
let k =
21252125
match k with
21262126
| `If -> "if"
2127+
| `Match -> "match"
21272128
| `While -> "while"
2128-
| `Sample -> "<$"
2129-
| `Call -> "<@"
2129+
| `Sample `LvmNone -> "<$"
2130+
| `Sample (`LvmVar pv) -> Format.asprintf "%a<$" (pp_pv ppe) pv
2131+
| `Call `LvmNone -> "<@"
2132+
| `Call (`LvmVar pv) -> Format.asprintf "%a<@" (pp_pv ppe) pv
21302133
| `Assign `LvmNone -> "<-"
21312134
| `Assign (`LvmVar pv) -> Format.asprintf "%a<-" (pp_pv ppe) pv
21322135
in Format.asprintf "^%s" k in
@@ -2146,14 +2149,20 @@ let pp_codeoffset1 (ppe : PPEnv.t) (fmt : Format.formatter) (offset : CP.codeoff
21462149
match offset with
21472150
| `ByPosition p -> Format.fprintf fmt "%a" (pp_codepos1 ppe) p
21482151
| `ByOffset o -> Format.fprintf fmt "%d" o
2149-
2152+
21502153
(* -------------------------------------------------------------------- *)
21512154
let pp_codepos (ppe : PPEnv.t) (fmt : Format.formatter) ((nm, cp1) : CP.codepos) =
2152-
let pp_nm (fmt : Format.formatter) ((cp, i) : CP.codepos1 * int) =
2153-
Format.eprintf "%a%s" (pp_codepos1 ppe) cp (if i = 0 then "." else "?")
2155+
let pp_nm (fmt : Format.formatter) ((cp, bs) : CP.codepos1 * CP.codepos_brsel) =
2156+
let bs =
2157+
match bs with
2158+
| `Cond true -> "."
2159+
| `Cond false -> "?"
2160+
| `Match cp -> Format.sprintf "#%s." cp
2161+
in
2162+
Format.fprintf fmt "%a%s" (pp_codepos1 ppe) cp bs
21542163
in
21552164

2156-
Format.eprintf "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1
2165+
Format.fprintf fmt "%a%a" (pp_list "" pp_nm) nm (pp_codepos1 ppe) cp1
21572166

21582167
(* -------------------------------------------------------------------- *)
21592168
let pp_opdecl_pr (ppe : PPEnv.t) fmt (basename, ts, ty, op) =

src/ecTyping.ml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3455,22 +3455,33 @@ let trans_lv_match ?(memory : memory option) (env : EcEnv.env) (p : plvmatch) :
34553455
(* -------------------------------------------------------------------- *)
34563456
let trans_cp_match ?(memory : memory option) (env : EcEnv.env) (p : pcp_match) : cp_match =
34573457
match p with
3458-
| (`Sample | `While | `Call | `If) as p ->
3458+
| (`While | `If | `Match) as p ->
34593459
(p :> cp_match)
3460+
| `Sample lv ->
3461+
`Sample (trans_lv_match ?memory env lv)
3462+
| `Call lv ->
3463+
`Call (trans_lv_match ?memory env lv)
34603464
| `Assign lv ->
34613465
`Assign (trans_lv_match ?memory env lv)
34623466
(* -------------------------------------------------------------------- *)
34633467
let trans_cp_base ?(memory : memory option) (env : EcEnv.env) (p : pcp_base) : cp_base =
34643468
match p with
34653469
| `ByPos _ as p -> (p :> cp_base)
34663470
| `ByMatch (i, p) -> `ByMatch (i, trans_cp_match ?memory env p)
3471+
34673472
(* -------------------------------------------------------------------- *)
34683473
let trans_codepos1 ?(memory : memory option) (env : EcEnv.env) (p : pcodepos1) : codepos1 =
34693474
snd_map (trans_cp_base ?memory env) p
34703475

3476+
(* -------------------------------------------------------------------- *)
3477+
let trans_codepos_brsel (bs : pbranch_select) : codepos_brsel =
3478+
match bs with
3479+
| `Cond b -> `Cond b
3480+
| `Match { pl_desc = x } -> `Match x
3481+
34713482
(* -------------------------------------------------------------------- *)
34723483
let trans_codepos ?(memory : memory option) (env : EcEnv.env) ((nm, p) : pcodepos) : codepos =
3473-
let nm = List.map (fst_map (trans_codepos1 ?memory env)) nm in
3484+
let nm = List.map (fun (cp1, bs) -> (trans_codepos1 ?memory env cp1, trans_codepos_brsel bs)) nm in
34743485
let p = trans_codepos1 ?memory env p in
34753486
(nm, p)
34763487

src/phl/ecPhlInline.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ module HiInternal = struct
309309
| Zp.ZWhile (_, sp) -> aux_s (IPwhile aout) sp
310310
| Zp.ZIfThen (_, sp, _) -> aux_s (IPif (aout, [])) sp
311311
| Zp.ZIfElse (_, _, sp) -> aux_s (IPif ([], aout)) sp
312+
| Zp.ZMatch (_, sp, mpi) ->
313+
let prebr = List.map (fun _ -> []) mpi.prebr in
314+
let postbr = List.map (fun _ -> []) mpi.postbr in
315+
aux_s (IPmatch (prebr @ aout :: postbr)) sp
312316

313317
and aux_s aout ((sl, _), ip) =
314318
aux_i [(List.length sl, aout)] ip

tests/match_codepos.ec

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
(* -------------------------------------------------------------------- *)
2+
require import Distr.
3+
4+
(* -------------------------------------------------------------------- *)
5+
module M = {
6+
proc f(x : bool option) = {
7+
var y;
8+
y <- false;
9+
match x with
10+
| None => {}
11+
| Some v => {
12+
if (v) {
13+
y <$ dunit ((y || true) && true);
14+
}
15+
}
16+
end;
17+
return y;
18+
}
19+
proc g(x : bool option) = {
20+
var z;
21+
z <- false;
22+
match x with
23+
| None => {}
24+
| Some v => {
25+
if (v) {
26+
z <$ dunit true;
27+
}
28+
}
29+
end;
30+
return z;
31+
}
32+
}.
33+
34+
(* -------------------------------------------------------------------- *)
35+
equiv l: M.f ~ M.g: ={arg} ==> ={res}.
36+
proof.
37+
proc.
38+
proc rewrite {1} ^match#Some.^if.^y<$ /=.
39+
by sim.
40+
qed.

0 commit comments

Comments
 (0)