|
| 1 | +(* Copyright (c) INRIA and Microsoft Corporation. All rights reserved. *) |
| 2 | +(* Licensed under the Apache 2.0 and MIT Licenses. *) |
| 3 | + |
| 4 | +(** Transform functions with early returns to use goto. When activated via |
| 5 | + -goto_for_early_return, any function whose body contains a return statement |
| 6 | + in non-tail position is rewritten so that: |
| 7 | + - a return variable is allocated at the top (for non-void functions), |
| 8 | + - every return is replaced by an assignment + goto, |
| 9 | + - a label and final return are appended at the end. |
| 10 | +
|
| 11 | + This pass operates on the CStar AST, before lowering to C11. *) |
| 12 | + |
| 13 | +open CStar |
| 14 | + |
(* Collect every identifier mentioned in [body] so that callers can pick fresh
   names that do not collide with anything already present. The walk records:
   - variable references, and the name component of qualified names and
     macros;
   - names introduced by declarations (including [For] initialisers);
   - label names, whether they are defined ([Label]) or targeted ([Goto]).

   Label names matter because this table is also used to choose the exit
   label of the goto-based rewriting: ignoring them would allow the fresh
   label to duplicate one that already exists in the function.

   Field names, type names and switch-case tags are not recorded.

   Returns a table used as a set: a name is "used" iff it is a key. *)
let collect_names (body: block): (string, unit) Hashtbl.t =
  let names = Hashtbl.create 32 in
  (* [replace] rather than [add]: one binding per name is enough. *)
  let add n = Hashtbl.replace names n () in
  let rec collect_expr (e: expr) =
    match e with
    | Var v -> add v
    | Qualified (_, n) | Macro (_, n) -> add n
    | Constant _ | Bool _ | StringLiteral _ | Any | EAbort _ | Type _
    | BufNull | Op _ -> ()
    | Cast (e, _) | Field (e, _) | AddrOf e | InlineComment (_, e, _) ->
        collect_expr e
    | BufRead (e1, e2) | BufSub (e1, e2) | Comma (e1, e2)
    | BufCreate (_, e1, e2) ->
        collect_expr e1; collect_expr e2
    | Call (e, es) -> collect_expr e; List.iter collect_expr es
    | BufCreateL (_, es) -> List.iter collect_expr es
    | Struct (_, fes) -> List.iter (fun (_, e) -> collect_expr e) fes
    | Stmt ss -> List.iter collect_stmt ss
  and collect_stmt (s: stmt) =
    match s with
    | Abort _ | Break | Continue | Comment _ -> ()
    (* Labels already present must be avoided when inventing a new one. *)
    | Goto l | Label l -> add l
    | Return (Some e) -> collect_expr e
    | Return None -> ()
    | Ignore e | BufFree (_, e) -> collect_expr e
    | Decl (b, e) -> add b.name; collect_expr e
    | Assign (e1, _, e2) -> collect_expr e1; collect_expr e2
    | BufWrite (e1, e2, e3) ->
        collect_expr e1; collect_expr e2; collect_expr e3
    | BufBlit (_, e1, e2, e3, e4, e5) ->
        List.iter collect_expr [e1; e2; e3; e4; e5]
    | BufFill (_, e1, e2, e3) ->
        List.iter collect_expr [e1; e2; e3]
    | IfThenElse (_, e, b1, b2) ->
        collect_expr e; collect_block b1; collect_block b2
    | While (e, b) -> collect_expr e; collect_block b
    | For (init, e, iter, b) ->
        (match init with
         | `Decl (bi, e') -> add bi.name; collect_expr e'
         | `Stmt s -> collect_stmt s
         | `Skip -> ());
        collect_expr e; collect_stmt iter; collect_block b
    | Switch (e, branches, default) ->
        collect_expr e;
        List.iter (fun (_, b) -> collect_block b) branches;
        Option.iter collect_block default
    | Block b -> collect_block b
  and collect_block b = List.iter collect_stmt b
  in
  collect_block body;
  names
| 68 | + |
(* Decide whether [body] contains an "early" return.

   A return is in terminal position when it is the last statement of the
   function body, or the last statement of a branch of an if-then-else /
   switch / nested block that is itself in terminal position. Such returns
   are not early. A return anywhere else is early; in particular, any return
   inside a loop body counts as early, because the loop may exit without
   returning.

   Only statements are inspected; expressions are not traversed. *)
let has_early_return (body: block): bool =
  (* [terminal] says whether the position being examined could be the last
     thing executed before the function returns naturally. *)
  let rec early_in_block ~terminal = function
    | [] -> false
    | [last] -> early_in_stmt ~terminal last
    | first :: rest ->
        (* Anything followed by more statements is not terminal. *)
        early_in_stmt ~terminal:false first || early_in_block ~terminal rest
  and early_in_stmt ~terminal = function
    | Return _ -> not terminal
    | IfThenElse (_, _, then_b, else_b) ->
        early_in_block ~terminal then_b || early_in_block ~terminal else_b
    | Switch (_, branches, default) ->
        List.exists (fun (_, b) -> early_in_block ~terminal b) branches
        || (match default with
            | Some b -> early_in_block ~terminal b
            | None -> false)
    | While (_, loop_body) | For (_, _, _, loop_body) ->
        early_in_block ~terminal:false loop_body
    | Block inner -> early_in_block ~terminal inner
    | Abort _ | Break | Continue | Comment _ | Goto _ | Label _
    | Ignore _ | Decl _ | Assign _ | BufWrite _ | BufBlit _ | BufFill _
    | BufFree _ ->
        false
  in
  early_in_block ~terminal:true body
| 106 | + |
(* Build the expression used to zero-initialise the return variable of a
   rewritten function, based on its type [t]:
   - integers get the literal 0 at their own width;
   - booleans get [false];
   - pointers get the null buffer;
   - every other non-void type gets a one-field struct literal holding 0,
     which CStarToC11 is expected to print as [{ 0U }].

   Raises [Failure] when called on [Void], which has no values. *)
let zero_initializer (t: typ): expr =
  (* Anonymous struct literal with a single unnamed zero field. *)
  let braced_zero : expr =
    Struct (None, [(None, Constant (Constant.UInt8, "0"))])
  in
  match t with
  | Void -> failwith "zero_initializer called on Void"
  | Bool -> Bool false
  | Int width -> Constant (width, "0")
  | Pointer _ -> BufNull
  | Qualified _ | Struct _ | Enum _ | Union _ | Array _ | Function _
  | Const _ ->
      braced_zero
| 120 | + |
(* Replace every [Return] found in [body] by a jump to [label].

   - [Return (Some e)] becomes "[ret_var] := e; goto [label]", the assignment
     being typed at [ret_typ];
   - [Return None] becomes "goto [label]";
   - a [Return (Some _)] met while [is_void] is set is reported as a fatal
     error.

   The traversal descends into the branches of if-then-else and switch, into
   loop bodies and into nested blocks. All other statements are kept as is;
   in particular, expressions and the init/step parts of [For] are not
   visited. *)
let rewrite_block (ret_var: ident) (ret_typ: typ) (label: ident) (is_void: bool) (body: block): block =
  let jump = Goto label in
  let rec on_block (b: block): block =
    (* One statement may expand into two, hence [concat_map]. *)
    List.concat_map on_stmt b
  and on_stmt (s: stmt): block =
    match s with
    | Return None ->
        [jump]
    | Return (Some value) ->
        if is_void then
          Warn.fatal_error "void function returning a value"
        else
          [Assign (Var ret_var, ret_typ, value); jump]
    | Block inner ->
        [Block (on_block inner)]
    | IfThenElse (ifdef, cond, then_b, else_b) ->
        [IfThenElse (ifdef, cond, on_block then_b, on_block else_b)]
    | Switch (scrutinee, branches, default) ->
        let branches = List.map (fun (tag, b) -> (tag, on_block b)) branches in
        let default = Option.map on_block default in
        [Switch (scrutinee, branches, default)]
    | While (cond, loop_body) ->
        [While (cond, on_block loop_body)]
    | For (init, cond, step, loop_body) ->
        [For (init, cond, step, on_block loop_body)]
    | other ->
        (* Statements that cannot contain a nested statement-level return. *)
        [other]
  in
  on_block body
| 148 | + |
(* Rewrite a single CStar declaration so that it has a single exit point.

   Only [Function] declarations whose body satisfies [has_early_return] are
   touched; every other declaration is returned unchanged. For a rewritten
   function:
   - non-void: a zero-initialised return variable is declared first, the
     rewritten body follows, then the exit label and the final
     "return <variable>";
   - void: the rewritten body is followed by the exit label and a plain
     "return".

   The names of the return variable and of the exit label are chosen fresh
   with respect to both the names occurring in the body and the names of the
   function's parameters. *)
let rewrite_decl (d: decl): decl =
  match d with
  | Function (cc, flags, ret_typ, lid, binders, body) when has_early_return body ->
      let is_void = (ret_typ = Void) in
      let used = collect_names body in
      (* Parameters share the function's outermost scope with the return
         variable declared below, so their names must be avoided even when
         the body never mentions them. *)
      List.iter (fun (b: binder) -> Hashtbl.replace used b.name ()) binders;
      let is_used n = Hashtbl.mem used n in
      let ret_var = Idents.mk_fresh "result" is_used in
      let label = Idents.mk_fresh "exit" is_used in
      let rewritten = rewrite_block ret_var ret_typ label is_void body in
      (* NOTE(review): the generated gotos jump forward past any declaration
         made at the top level of the body. Confirm this is acceptable for
         all targets (e.g. variable-length arrays, C++ compilation), or wrap
         [rewritten] in a [Block]. *)
      let new_body =
        if is_void then
          (* In C a label must be attached to a statement, so it cannot be
             the very last thing in the function body: follow it with an
             explicit return. *)
          rewritten @ [Label label; Return None]
        else
          let init = zero_initializer ret_typ in
          [Decl ({ name = ret_var; typ = ret_typ }, init)]
          @ rewritten
          @ [Label label; Return (Some (Var ret_var))]
      in
      Function (cc, flags, ret_typ, lid, binders, new_body)
  | d -> d
| 170 | + |
(* Apply [rewrite_decl] to every declaration of a file, preserving their
   order. Declarations that do not need rewriting are returned unchanged. *)
let rewrite_file: decl list -> decl list =
  List.map rewrite_decl
0 commit comments