Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 119 additions & 64 deletions src/smtml/expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,21 @@ let rec ty (hte : t) : Ty.t =
| Symbol x -> Symbol.type_of x
| List _ -> Ty_list
| App (sym, _) -> begin match sym.ty with Ty_none -> Ty_app | ty -> ty end
| Unop (ty, _, _) -> ty
| Binop (ty, _, _, _) -> ty
| Triop (_, Ite, _, hte1, hte2) ->
let ty1 = ty hte1 in
let ty2 = ty hte2 in
assert (Ty.equal ty1 ty2);
assert (
let ty2 = ty hte2 in
Ty.equal ty1 ty2 );
ty1
| Triop (ty, _, _, _, _) -> ty
| Relop (ty, _, _, _) -> ty
| Cvtop (_, (Zero_extend m | Sign_extend m), hte) -> (
match ty hte with Ty_bitv n -> Ty_bitv (n + m) | _ -> assert false )
| Cvtop (ty, _, _) -> ty
| Naryop (ty, _, _) -> ty
| Unop (ty, _, _)
| Binop (ty, _, _, _)
| Triop (ty, _, _, _, _)
| Relop (ty, _, _, _)
| Cvtop (ty, _, _)
| Naryop (ty, _, _) ->
ty
| Extract (_, h, l) -> Ty_bitv ((h - l) * 8)
| Concat (e1, e2) -> (
match (ty e1, ty e2) with
Expand All @@ -138,22 +140,16 @@ let rec ty (hte : t) : Ty.t =

let rec is_symbolic (v : t) : bool =
match view v with
| Val _ -> false
| Val _ | Loc _ -> false
| Symbol _ -> true
| Loc _ -> false
| Ptr { offset; _ } -> is_symbolic offset
| List vs -> List.exists is_symbolic vs
| App (_, vs) -> List.exists is_symbolic vs
| Unop (_, _, v) -> is_symbolic v
| Binop (_, _, v1, v2) -> is_symbolic v1 || is_symbolic v2
| Unop (_, _, v) | Cvtop (_, _, v) | Extract (v, _, _) | Binder (_, _, v) ->
is_symbolic v
| Binop (_, _, v1, v2) | Relop (_, _, v1, v2) | Concat (v1, v2) ->
is_symbolic v1 || is_symbolic v2
| Triop (_, _, v1, v2, v3) ->
is_symbolic v1 || is_symbolic v2 || is_symbolic v3
| Cvtop (_, _, v) -> is_symbolic v
| Relop (_, _, v1, v2) -> is_symbolic v1 || is_symbolic v2
| Naryop (_, _, vs) -> List.exists is_symbolic vs
| Extract (e, _, _) -> is_symbolic e
| Concat (e1, e2) -> is_symbolic e1 || is_symbolic e2
| Binder (_, _, e) -> is_symbolic e
| List vs | App (_, vs) | Naryop (_, _, vs) -> List.exists is_symbolic vs

let get_symbols (hte : t list) =
let tbl = Hashtbl.create 64 in
Expand Down Expand Up @@ -245,50 +241,6 @@ let pp_smtml fmt (es : t list) : unit =

let to_string e = Fmt.str "%a" pp e

module Set = struct
include PatriciaTree.MakeHashconsedSet (Key) ()

let hash = to_int

let pp fmt v =
Fmt.pf fmt "@[<hov 1>%a@]"
(pretty ~pp_sep:(fun fmt () -> Fmt.pf fmt "@;") pp)
v

let get_symbols (set : t) =
let tbl = Hashtbl.create 64 in
let rec symbols hte =
match view hte with
| Val _ | Loc _ -> ()
| Ptr { offset; _ } -> symbols offset
| Symbol s -> Hashtbl.replace tbl s ()
| List es -> List.iter symbols es
| App (_, es) -> List.iter symbols es
| Unop (_, _, e1) -> symbols e1
| Binop (_, _, e1, e2) ->
symbols e1;
symbols e2
| Triop (_, _, e1, e2, e3) ->
symbols e1;
symbols e2;
symbols e3
| Relop (_, _, e1, e2) ->
symbols e1;
symbols e2
| Cvtop (_, _, e) -> symbols e
| Naryop (_, _, es) -> List.iter symbols es
| Extract (e, _, _) -> symbols e
| Concat (e1, e2) ->
symbols e1;
symbols e2
| Binder (_, vars, e) ->
List.iter symbols vars;
symbols e
in
iter symbols set;
Hashtbl.fold (fun k () acc -> k :: acc) tbl []
end

let value (v : Value.t) : t = make (Val v) [@@inline]

let ptr base offset = make (Ptr { base = Bitvector.of_int32 base; offset })
Expand Down Expand Up @@ -819,3 +771,106 @@ module Smtlib = struct
| Concat _ -> assert false
| Binder _ -> assert false
end

let inline_symbol_values map e =
let rec aux e =
match view e with
| Val _ | Loc _ -> e
| Symbol symbol -> Option.value ~default:e (Symbol.Map.find_opt symbol map)
| Ptr e ->
let offset = aux e.offset in
make @@ Ptr { e with offset }
| List vs ->
let vs = List.map aux vs in
list vs
| App (x, vs) ->
let vs = List.map aux vs in
app x vs
| Unop (ty, op, v) ->
let v = aux v in
unop ty op v
| Binop (ty, op, v1, v2) ->
let v1 = aux v1 in
let v2 = aux v2 in
binop ty op v1 v2
| Triop (ty, op, v1, v2, v3) ->
let v1 = aux v1 in
let v2 = aux v2 in
let v3 = aux v3 in
triop ty op v1 v2 v3
| Cvtop (ty, op, v) ->
let v = aux v in
cvtop ty op v
| Relop (ty, op, v1, v2) ->
let v1 = aux v1 in
let v2 = aux v2 in
relop ty op v1 v2
| Naryop (ty, op, vs) ->
let vs = List.map aux vs in
naryop ty op vs
| Extract (e, high, low) ->
let e = aux e in
extract e ~high ~low
| Concat (e1, e2) ->
let e1 = aux e1 in
let e2 = aux e2 in
concat e1 e2
| Binder (b, vars, e) ->
let e = aux e in
binder b vars e
in
aux e

module Set = struct
include PatriciaTree.MakeHashconsedSet (Key) ()

let hash = to_int

let pp fmt v =
Fmt.pf fmt "@[<hov 1>%a@]"
(pretty ~pp_sep:(fun fmt () -> Fmt.pf fmt "@;") pp)
v

let get_symbols (set : t) =
let tbl = Hashtbl.create 64 in
let rec symbols hte =
match view hte with
| Val _ | Loc _ -> ()
| Ptr { offset; _ } -> symbols offset
| Symbol s -> Hashtbl.replace tbl s ()
| List es -> List.iter symbols es
| App (_, es) -> List.iter symbols es
| Unop (_, _, e1) -> symbols e1
| Binop (_, _, e1, e2) ->
symbols e1;
symbols e2
| Triop (_, _, e1, e2, e3) ->
symbols e1;
symbols e2;
symbols e3
| Relop (_, _, e1, e2) ->
symbols e1;
symbols e2
| Cvtop (_, _, e) -> symbols e
| Naryop (_, _, es) -> List.iter symbols es
| Extract (e, _, _) -> symbols e
| Concat (e1, e2) ->
symbols e1;
symbols e2
| Binder (_, vars, e) ->
List.iter symbols vars;
symbols e
in
iter symbols set;
Hashtbl.fold (fun k () acc -> k :: acc) tbl []

let map f set =
fold
(fun elt set ->
let elt = f elt in
add elt set )
set empty

let inline_symbol_values symbol_map set =
map (inline_symbol_values symbol_map) set
end
11 changes: 11 additions & 0 deletions src/smtml/expr_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ module type S = sig
an error if the expression is not a relational operation. *)
val negate_relop : t -> t

(** [inline_symbol_values symbol_map e] replaces each symbol [e] expressions
of [set] by its image in [symbol_map]. *)
val inline_symbol_values : t Symbol.Map.t -> t -> t

(** {1 Pretty Printing} *)

(** [pp fmt term] prints a term in a human-readable format using the formatter
Expand Down Expand Up @@ -420,6 +424,9 @@ module type S = sig
of {!to_int}. *)
val iter : (elt -> unit) -> t -> unit

(** [map f set] maps all elements of [set] to their image by [f]. *)
val map : (elt -> elt) -> t -> t

(** [filter f set] is the subset of [set] that only contains the elements
that satisfy [f]. [f] is called in the unsigned order of {!to_int}. *)
val filter : (elt -> bool) -> t -> t
Expand Down Expand Up @@ -518,6 +525,10 @@ module type S = sig
(** [get_symbols exprs] extracts all symbolic variables from a list of
expressions. *)
val get_symbols : t -> Symbol.t list

(** [inline_symbol_values symbol_map set] replaces each symbol in all
expressions of [set] by its image in [symbol_map]. *)
val inline_symbol_values : elt Symbol.Map.t -> t -> t
end

(** {1 Bitvectors} *)
Expand Down
31 changes: 31 additions & 0 deletions test/unit/test_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,36 @@ let test_simplify =
; "test_simplify_ptr" >:: test_simplify_ptr
]

let test_inline_symbol_values_empty (_ : test_ctxt) =
let symbol_map = Symbol.Map.empty in
let e =
let ty = Ty.Ty_bitv 32 in
Infix.symbol "x" ty
in
let e' = Expr.inline_symbol_values symbol_map e in
(* We should not have changed the symbol value, and it should even stay physically equal to its original value. *)
assert (e == e')

let test_inline_symbol_values_replace_one (_ : test_ctxt) =
let n = Infix.int 42 in
let e' =
let x =
let ty = Ty.Ty_bitv 32 in
Symbol.make ty "x"
in
let symbol_map = Symbol.Map.add x n Symbol.Map.empty in
let e = Expr.symbol x in
Expr.inline_symbol_values symbol_map e
in
(* e should now be equal to n because symbol x should have been replaced by n *)
assert (e' == n)

let test_inline_symbol_values =
[ "test_inline_symbol_values_empty" >:: test_inline_symbol_values_empty
; "test_inline_symbol_values_replace_one"
>:: test_inline_symbol_values_replace_one
]

let test_suite =
"Expression unit tests"
>::: [ "test_hc" >:: test_hc
Expand All @@ -660,6 +690,7 @@ let test_suite =
; "test_cvtop" >::: test_cvtop
; "test_naryop" >::: test_naryop
; "test_simplify" >::: test_simplify
; "test_inline_symbol_values" >::: test_inline_symbol_values
]

let () = run_test_tt_main test_suite
Loading