|
| 1 | +(* Copyright (c) INRIA and Microsoft Corporation. All rights reserved. *) |
| 2 | +(* Licensed under the Apache 2.0 and MIT Licenses. *) |
| 3 | + |
| 4 | +(** Transform functions with early returns to use goto. When activated via |
| 5 | + -goto_for_early_return, any function whose body contains a return statement |
| 6 | + in non-tail position is rewritten so that: |
| 7 | + - a return variable is allocated at the top (for non-void functions), |
| 8 | + - every return is replaced by an assignment + goto, |
| 9 | + - a label and final return are appended at the end. *) |
| 10 | + |
| 11 | +module C = C11 |
| 12 | + |
(* Collect every name occurring in a statement tree: variables referenced in
   expressions, identifiers introduced by declarations, and the targets of
   goto / label statements. The result is a set, encoded as a hash table with
   unit values.

   Labels are recorded on purpose: the set is used to pick a fresh exit label
   as well as a fresh variable name, and a generated label must not duplicate
   a label already present in the function body. Keeping variables and labels
   in one table is conservative (C puts them in separate name spaces); the
   worst outcome is a harmless extra renaming.

   Member names, designators and type specifiers are not inspected. *)
let collect_names (s: C.stmt): (string, unit) Hashtbl.t =
  let names = Hashtbl.create 32 in
  let add n = Hashtbl.replace names n () in
  let rec collect_expr (e: C.expr) =
    match e with
    | C.Name n -> add n
    | C.Op1 (_, e) | C.Deref e | C.Address e | C.Cast (_, e)
    | C.Sizeof e | C.MemberAccess (e, _) | C.MemberAccessPointer (e, _)
    | C.InlineComment (_, e, _) -> collect_expr e
    | C.Op2 (_, e1, e2) | C.Index (e1, e2) | C.Member (e1, e2)
    | C.MemberP (e1, e2) | C.Assign (e1, e2) ->
        collect_expr e1; collect_expr e2
    | C.Call (e, es) -> collect_expr e; List.iter collect_expr es
    | C.CompoundLiteral (_, inits) -> List.iter collect_init inits
    | C.Stmt ss -> List.iter collect_stmt ss
    | C.CxxInitializerList init -> collect_init init
    | C.Constant _ | C.Literal _ | C.Bool _ | C.Type _ -> ()
  and collect_init (i: C.init) =
    match i with
    | C.InitExpr e -> collect_expr e
    | C.Designated (_, i) -> collect_init i
    | C.Initializer is -> List.iter collect_init is
  (* Only the declared identifier is recorded; a Function declarator's
     parameter list is not visited. *)
  and collect_decl_ident (d: C.declarator) =
    match d with
    | C.Ident n -> add n
    | C.Pointer (_, d) | C.Array (_, d, _) | C.Function (_, d, _) ->
        collect_decl_ident d
  and collect_stmt (s: C.stmt) =
    match s with
    | C.Compound ss -> List.iter collect_stmt ss
    | C.Decl (_, _, _, _, _, dais) ->
        List.iter (fun (d, _, init) ->
          collect_decl_ident d;
          Option.iter collect_init init) dais
    | C.Expr e -> collect_expr e
    | C.If (e, s) -> collect_expr e; collect_stmt s
    | C.IfElse (e, s1, s2) -> collect_expr e; collect_stmt s1; collect_stmt s2
    | C.While (e, s) -> collect_expr e; collect_stmt s
    | C.For (doe, e1, e2, s) ->
        (match doe with
         | `Decl d -> collect_stmt (C.Decl d)
         | `Expr e -> collect_expr e
         | `Skip -> ());
        collect_expr e1; collect_expr e2; collect_stmt s
    | C.Return (Some e) -> collect_expr e
    | C.Return None | C.Break | C.Continue | C.Comment _ -> ()
    | C.Switch (e, branches, default) ->
        collect_expr e;
        List.iter (fun (e, s) -> collect_expr e; collect_stmt s) branches;
        collect_stmt default
    | C.IfDef (_, ss1, elif_blocks, ss2) ->
        List.iter collect_stmt ss1;
        List.iter (fun (_, ss) -> List.iter collect_stmt ss) elif_blocks;
        List.iter collect_stmt ss2
    (* Existing labels count as used names, so that a freshly chosen exit
       label cannot collide with them. *)
    | C.Goto l | C.Label l -> add l
  in
  collect_stmt s;
  names
| 71 | + |
(* Number of Return statements reachable by walking compound statements,
   conditionals, loops, switch branches and preprocessor-conditional blocks.
   Any other statement form contributes zero; in particular, statements
   embedded inside expressions are not visited. *)
let rec count_returns (s: C.stmt): int =
  (* Sum of the counts over a list of statements. *)
  let total stmts =
    List.fold_left (fun sum stmt -> sum + count_returns stmt) 0 stmts
  in
  match s with
  | C.Return _ -> 1
  | C.Compound stmts -> total stmts
  | C.If (_, body) | C.While (_, body) | C.For (_, _, _, body) ->
      count_returns body
  | C.IfElse (_, on_true, on_false) ->
      count_returns on_true + count_returns on_false
  | C.Switch (_, branches, default) ->
      List.fold_left
        (fun sum (_, branch) -> sum + count_returns branch)
        (count_returns default)
        branches
  | C.IfDef (_, then_block, elif_blocks, else_block) ->
      let in_elifs =
        List.fold_left (fun sum (_, block) -> sum + total block) 0 elif_blocks
      in
      total then_block + in_elifs + total else_block
  | _ -> 0
| 89 | + |
(* Decide whether a function body needs the goto rewriting.
   Only compound bodies qualify. Among those:
   - no return at all: nothing to do;
   - exactly one return: it is early unless it is the very last statement of
     the outermost compound;
   - two or more returns: always treated as containing an early return. *)
let has_early_return (body: C.stmt): bool =
  match body with
  | C.Compound stmts ->
      begin match count_returns body with
      | 0 -> false
      | 1 ->
          (* The single return is fine only when it closes the body. *)
          begin match List.rev stmts with
          | C.Return _ :: _ -> false
          | _ -> true
          end
      | _ -> true
      end
  | _ -> false
| 108 | + |
(* Rewrite a statement tree so that every Return it reaches becomes a jump to
   [label].
   - A bare return becomes a goto, whether or not the function is void.
   - A return carrying a value becomes a two-statement block: first either an
     assignment of the value to [ret_var] (non-void function) or the value
     evaluated as an expression statement (void function), then the goto.
   The traversal descends through compound statements, conditionals, loops,
   switch branches and preprocessor-conditional blocks; every other statement
   is returned unchanged. *)
let rec rewrite_stmt (ret_var: string) (label: string) (is_void: bool) (s: C.stmt): C.stmt =
  let go = rewrite_stmt ret_var label is_void in
  let go_all = List.map go in
  match s with
  | C.Return None ->
      C.Goto label
  | C.Return (Some value) ->
      let side =
        if is_void then
          (* A void function is not expected to return a value; keep the
             expression so that it is still evaluated. *)
          C.Expr value
        else
          C.Expr (C.Assign (C.Name ret_var, value))
      in
      C.Compound [side; C.Goto label]
  | C.Compound body ->
      C.Compound (go_all body)
  | C.If (cond, body) ->
      C.If (cond, go body)
  | C.IfElse (cond, on_true, on_false) ->
      C.IfElse (cond, go on_true, go on_false)
  | C.While (cond, body) ->
      C.While (cond, go body)
  | C.For (init, cond, step, body) ->
      C.For (init, cond, step, go body)
  | C.Switch (scrutinee, branches, default) ->
      let branches = List.map (fun (case, body) -> (case, go body)) branches in
      C.Switch (scrutinee, branches, go default)
  | C.IfDef (cond, then_block, elif_blocks, else_block) ->
      let elif_blocks =
        List.map (fun (elif_cond, block) -> (elif_cond, go_all block)) elif_blocks
      in
      C.IfDef (cond, go_all then_block, elif_blocks, go_all else_block)
  | other -> other
| 144 | + |
(* Turn the declarator of a function definition into the declarator of its
   return value by removing the Function wrapper.

   When the Function node is outermost, its inner declarator is returned
   as-is. Otherwise the search continues underneath Pointer and Array nodes,
   which are rebuilt around the result. This covers the case where a function
   returning a pointer is represented the way C spells it, that is, a Pointer
   node wrapped around the Function node: without the descent the Function
   node would survive, and a variable declared from the result would be
   declared as a function.

   NOTE(review): which of the two shapes is actually produced for
   pointer-returning functions is not visible from this file; confirm against
   the code that builds function declarators. A declarator that contains no
   Function node is returned structurally unchanged. *)
let rec extract_return_declarator (d: C.declarator): C.declarator =
  match d with
  | C.Function (_, inner, _) -> inner
  | C.Pointer (qs, d) -> C.Pointer (qs, extract_return_declarator d)
  | C.Array (qs, d, e) -> C.Array (qs, extract_return_declarator d, e)
  | C.Ident _ -> d
| 151 | + |
(* Rebuild a declarator with the identifier at its core swapped for
   [new_name]; all Pointer, Array and Function layers around it are kept. *)
let rec replace_ident (new_name: string) (d: C.declarator): C.declarator =
  let rename = replace_ident new_name in
  match d with
  | C.Ident _ -> C.Ident new_name
  | C.Pointer (quals, inner) -> C.Pointer (quals, rename inner)
  | C.Array (quals, inner, size) -> C.Array (quals, rename inner, size)
  | C.Function (conv, inner, params) -> C.Function (conv, rename inner, params)
| 159 | + |
(* True exactly when the type specifier is Void and the declaration consists
   of a single Function declarator wrapped directly around an identifier.
   Any other shape (another specifier, several declarators, or extra layers
   between the Function node and the identifier) yields false. *)
let is_void_return (spec: C.type_spec) (decl_and_inits: C.declarator_and_inits): bool =
  match spec, decl_and_inits with
  | C.Void, [(C.Function (_, C.Ident _, _), _, _)] -> true
  | _ -> false
| 169 | + |
(* Build a zero-like initializer for a variable of the given type.
   - outermost declarator is a Pointer: the name NULL;
   - boolean (Int Bool, or Named "bool" / "BOOLEAN"): false;
   - any other Int width: the constant 0 tagged with that width;
   - everything else (Void, other Named types, Struct, Union, Enum): a braced
     initializer list holding a single 0.
   Only the outermost declarator node is looked at. *)
let zero_initializer (spec: C.type_spec) (ret_decl: C.declarator): C.init =
  let scalar e = C.InitExpr e in
  match ret_decl with
  | C.Pointer _ ->
      scalar (C.Name "NULL")
  | C.Ident _ | C.Array _ | C.Function _ ->
      begin match spec with
      (* The original note states that KaRaMeL maps Bool to Named "bool", or
         Named "BOOLEAN" in the Microsoft flavour. *)
      | C.Int Constant.Bool | C.Named ("bool" | "BOOLEAN") ->
          scalar (C.Bool false)
      | C.Int width ->
          scalar (C.Constant (width, "0"))
      | C.Void | C.Named _ | C.Struct _ | C.Union _ | C.Enum _ ->
          C.Initializer [scalar (C.Constant (Constant.UInt8, "0"))]
      end
| 188 | + |
(* Neutral [extra] record used for the synthesized return-variable
   declaration: not flagged maybe_unused, and no target. *)
let no_extra: C.extra = { target = None; maybe_unused = false }
| 190 | + |
(* Rewrite one top-level item. Function definitions whose body has an early
   return (see has_early_return) get a new body of the shape

     <return variable declaration>   (non-void functions only)
     { original statements, returns replaced by assignment + goto }
     <exit label>
     return <return variable>;       (plain return for void functions)

   Every other item is returned unchanged.

   Raises Failure "unexpected declarator_and_inits" when a non-void function
   does not carry exactly one declarator.

   NOTE(review): fresh names are chosen against the names found in the body
   only; a parameter name that the body never mentions is not taken into
   account. Confirm whether that can clash with the return variable.
   NOTE(review): the function's qualifiers are reused verbatim for the return
   variable, which is later assigned to; confirm they can never make it
   read-only. *)
let rewrite_decl (d: C.declaration_or_function): C.declaration_or_function =
  match d with
  | C.Function (comments,
      ((qs, spec, _inline, _stor, _extra, decl_and_inits) as decl),
      body) when has_early_return body ->
      let void_fn = is_void_return spec decl_and_inits in
      (* Names already present in the body must be avoided. *)
      let seen = collect_names body in
      let taken name = Hashtbl.mem seen name in
      let ret_var = Idents.mk_fresh "result" taken in
      let label = Idents.mk_fresh "exit" taken in
      let original_stmts =
        match body with
        | C.Compound stmts -> stmts
        | single -> [single]
      in
      (* The rewritten statements stay in their own nested block. *)
      let inner_block =
        C.Compound (List.map (rewrite_stmt ret_var label void_fn) original_stmts)
      in
      let new_stmts =
        if void_fn then
          (* Nothing to carry out of the function: label + plain return. *)
          [inner_block; C.Label label; C.Return None]
        else begin
          let fn_declarator =
            match decl_and_inits with
            | [(declarator, _, _)] -> declarator
            | _ -> failwith "unexpected declarator_and_inits"
          in
          let return_shape = extract_return_declarator fn_declarator in
          let var_declarator = replace_ident ret_var return_shape in
          let init = zero_initializer spec return_shape in
          let var_decl =
            C.Decl (qs, spec, None, None, no_extra, [var_declarator, None, Some init])
          in
          [var_decl; inner_block; C.Label label; C.Return (Some (C.Name ret_var))]
        end
      in
      C.Function (comments, decl, C.Compound new_stmts)
  | untouched -> untouched
| 232 | + |
(* Run rewrite_decl over every top-level item of a file, preserving order. *)
let rewrite_file (decls: C.declaration_or_function list): C.declaration_or_function list =
  decls |> List.map rewrite_decl
0 commit comments