@@ -19,8 +19,9 @@ module Position = struct
1919 | `If
2020 | `While
2121 | `Assign of lvmatch
22- | `Sample
23- | `Call
22+ | `Sample of lvmatch
23+ | `Call of lvmatch
24+ | `Match
2425 ]
2526
2627 and lvmatch = [ `LvmNone | `LvmVar of EcTypes. prog_var ]
@@ -30,9 +31,10 @@ module Position = struct
3031 | `ByMatch of int option * cp_match
3132 ]
3233
33- type codepos1 = int * cp_base
34- type codepos = (codepos1 * int ) list * codepos1
35- type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1 ]
34+ type codepos_brsel = [`Cond of bool | `Match of EcSymbols .symbol ]
35+ type codepos1 = int * cp_base
36+ type codepos = (codepos1 * codepos_brsel ) list * codepos1
37+ type codeoffset1 = [`ByOffset of int | `ByPosition of codepos1 ]
3638
3739 let shift1 ~(offset : int ) ((o , p ) : codepos1 ) : codepos1 =
3840 (o + offset, p)
@@ -57,12 +59,19 @@ module Zipper = struct
5759 type ('a, 'state) folder =
5860 'a -> 'state -> instr -> 'state * instr list
5961
62+ type spath_match_ctxt = {
63+ locals : (EcIdent .t * ty ) list ;
64+ prebr : ((EcIdent .t * ty ) list * stmt ) list ;
65+ postbr : ((EcIdent .t * ty ) list * stmt ) list ;
66+ }
67+
6068 type ipath =
6169 | ZTop
6270 | ZWhile of expr * spath
6371 | ZIfThen of expr * spath * stmt
6472 | ZIfElse of expr * stmt * spath
65-
73+ | ZMatch of expr * spath * spath_match_ctxt
74+
6675 and spath = (instr list * instr list ) * ipath
6776
6877 type zipper = {
@@ -95,9 +104,12 @@ module Zipper = struct
95104 match ir.i_node, cm with
96105 | Swhile _ , `While -> i-1
97106 | Sif _ , `If -> i-1
98- | Srnd _ , `Sample -> i-1
99- | Scall _ , `Call -> i-1
107+ | Smatch _ , `Match -> i-1
108+
109+ | Scall (None, _ , _ ), `Call `LvmNone -> i-1
100110
111+ | Scall (Some lv, _, _), `Call lvm
112+ | Srnd (lv, _), `Sample lvm
101113 | Sasgn (lv , _ ), `Assign lvm -> begin
102114 match lv, lvm with
103115 | _ , `LvmNone -> i-1
@@ -178,23 +190,34 @@ module Zipper = struct
178190
179191 let zipper_at_nm_cpos1
180192 (env : EcEnv.env )
181- ((cp1 , sub ) : codepos1 * int )
193+ ((cp1 , sub ) : codepos1 * codepos_brsel )
182194 (s : stmt )
183195 (zpr : ipath )
184- : (ipath * stmt ) * (codepos1 * int )
196+ : (ipath * stmt ) * (codepos1 * codepos_brsel )
185197 =
186198 let (s1, i, s2) = find_by_cpos1 env cp1 s in
187199 let zpr =
188200 match i.i_node, sub with
189- | Swhile (e , sw ), 0 ->
201+ | Swhile (e , sw ), `Cond true ->
190202 (ZWhile (e, ((s1, s2), zpr)), sw)
191203
192- | Sif (e , ifs1 , ifs2 ), 0 ->
204+ | Sif (e , ifs1 , ifs2 ), `Cond true ->
193205 (ZIfThen (e, ((s1, s2), zpr), ifs2), ifs1)
194206
195- | Sif (e , ifs1 , ifs2 ), 1 ->
207+ | Sif (e , ifs1 , ifs2 ), `Cond false ->
196208 (ZIfElse (e, ifs1, ((s1, s2), zpr)), ifs2)
197209
210+ | Smatch (e , bs ), `Match cn ->
211+ let _, indt, _ = oget (EcEnv.Ty. get_top_decl e.e_ty env) in
212+ let indt = oget (EcDecl. tydecl_as_datatype indt) in
213+ let cnames = List. fst indt.tydt_ctors in
214+ let ix, _ =
215+ try List. findi (fun _ n -> EcSymbols. sym_equal cn n) cnames
216+ with Not_found -> raise InvalidCPos
217+ in
218+ let prebr, (locals, body), postbr = List. pivot_at ix bs in
219+ (ZMatch (e, ((s1, s2), zpr), { locals; prebr; postbr; }), body)
220+
198221 | _ -> raise InvalidCPos
199222 in zpr, ((0 , `ByPos (1 + List. length s1)), sub)
200223
@@ -228,6 +251,8 @@ module Zipper = struct
228251 | ZWhile (e , sp ) -> zip (Some (i_while (e, s))) sp
229252 | ZIfThen (e , sp , se ) -> zip (Some (i_if (e, s, se))) sp
230253 | ZIfElse (e , se , sp ) -> zip (Some (i_if (e, se, s))) sp
254+ | ZMatch (e , sp , mpi ) ->
255+ zip (Some (i_match (e, mpi.prebr @ (mpi.locals, s) :: mpi.postbr))) sp
231256
232257 let zip zpr = zip None ((zpr.z_head, zpr.z_tail), zpr.z_path)
233258
@@ -238,6 +263,7 @@ module Zipper = struct
238263 | ZWhile (_ , ((_ , is ), ip )) -> doit (is :: acc) ip
239264 | ZIfThen (_ , ((_ , is ), ip ), _ ) -> doit (is :: acc) ip
240265 | ZIfElse (_ , _ , ((_ , is ), ip )) -> doit (is :: acc) ip
266+ | ZMatch (_ , ((_ , is ), ip ), _ ) -> doit (is :: acc) ip
241267 in
242268
243269 let after =
@@ -1298,6 +1324,10 @@ module RegexpBaseInstr = struct
12981324 let z' = zipper head tail path in
12991325 next_zipper z'
13001326
1327+ | ZMatch (_ , ((head , tail ), path ), _ ) ->
1328+ let z' = zipper head tail path in
1329+ next_zipper z'
1330+
13011331 let next (e : engine ) =
13021332 next_zipper e.e_zipper |> omap (fun z ->
13031333 { e with e_zipper = z; e_pos = List. length z.z_head })
0 commit comments