diff --git a/src/smtml/expr.ml b/src/smtml/expr.ml index 4b387973..81f02783 100644 --- a/src/smtml/expr.ml +++ b/src/smtml/expr.ml @@ -26,9 +26,7 @@ and expr = module Expr = struct type t = expr - let list_eq (l1 : 'a list) (l2 : 'a list) : bool = - if List.compare_lengths l1 l2 = 0 then List.for_all2 phys_equal l1 l2 - else false + let list_eq (l1 : 'a list) (l2 : 'a list) : bool = List.equal phys_equal l1 l2 let equal (e1 : expr) (e2 : expr) : bool = match (e1, e2) with @@ -182,38 +180,28 @@ let rec is_symbolic (v : t) : bool = is_symbolic v1 || is_symbolic v2 || is_symbolic v3 | List vs | App (_, vs) | Naryop (_, _, vs) -> List.exists is_symbolic vs +let rec get_symbols_aux acc (hte : t) = + match view hte with + | Val _ -> acc + | Ptr { offset; _ } -> get_symbols_aux acc offset + | Symbol s -> s :: acc + | List es | App (_, es) | Naryop (_, _, es) -> + List.fold_left get_symbols_aux acc es + | Unop (_, _, e) | Cvtop (_, _, e) | Extract (e, _, _) -> + get_symbols_aux acc e + | Binop (_, _, e1, e2) | Relop (_, _, e1, e2) | Concat (e1, e2) -> + let acc = get_symbols_aux acc e1 in + get_symbols_aux acc e2 + | Triop (_, _, e1, e2, e3) -> + let acc = get_symbols_aux acc e1 in + let acc = get_symbols_aux acc e2 in + get_symbols_aux acc e3 + | Binder (_, vars, e) -> + let acc = List.fold_left get_symbols_aux acc vars in + get_symbols_aux acc e + let get_symbols (hte : t list) = - let tbl = Hashtbl.create 64 in - let rec symbols (hte : t) = - match view hte with - | Val _ -> () - | Ptr { offset; _ } -> symbols offset - | Symbol s -> Hashtbl.replace tbl s () - | List es -> List.iter symbols es - | App (_, es) -> List.iter symbols es - | Unop (_, _, e1) -> symbols e1 - | Binop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Triop (_, _, e1, e2, e3) -> - symbols e1; - symbols e2; - symbols e3 - | Relop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Cvtop (_, _, e) -> symbols e - | Naryop (_, _, es) -> List.iter symbols es - | Extract (e, _, _) -> symbols e - | Concat (e1, e2) -> - symbols e1; - symbols e2 - | Binder (_, vars, e) -> - List.iter symbols vars; - symbols e - in - List.iter symbols hte; - Hashtbl.fold (fun k () acc -> k :: acc) tbl [] + List.fold_left get_symbols_aux [] hte |> List.sort_uniq Symbol.compare let rec pp_with ~printer fmt (hte : t) = match view hte with @@ -823,37 +811,8 @@ module Set = struct v let get_symbols (set : t) = - let tbl = Hashtbl.create 64 in - let rec symbols hte = - match view hte with - | Val _ -> () - | Ptr { offset; _ } -> symbols offset - | Symbol s -> Hashtbl.replace tbl s () - | List es -> List.iter symbols es - | App (_, es) -> List.iter symbols es - | Unop (_, _, e1) -> symbols e1 - | Binop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Triop (_, _, e1, e2, e3) -> - symbols e1; - symbols e2; - symbols e3 - | Relop (_, _, e1, e2) -> - symbols e1; - symbols e2 - | Cvtop (_, _, e) -> symbols e - | Naryop (_, _, es) -> List.iter symbols es - | Extract (e, _, _) -> symbols e - | Concat (e1, e2) -> - symbols e1; - symbols e2 - | Binder (_, vars, e) -> - List.iter symbols vars; - symbols e - in - iter symbols set; - Hashtbl.fold (fun k () acc -> k :: acc) tbl [] + fold (fun x acc -> get_symbols_aux acc x) set [] + |> List.sort_uniq Symbol.compare let map f set = fold diff --git a/src/smtml/mappings.ml b/src/smtml/mappings.ml index 1c597413..d0d01fd2 100644 --- a/src/smtml/mappings.ml +++ b/src/smtml/mappings.ml @@ -843,13 +843,13 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct List.iter (fun sym -> let v = value model0 (Expr.symbol sym) in - Hashtbl.replace m sym v ) + Hashtbl.add m sym v ) symbols | None -> Smap.iter (fun (sym : Symbol.t) term -> let v = Encoder.value_of_term ~ctx model sym.ty term in - Hashtbl.replace m sym v ) + Hashtbl.add m sym v ) ctx ); m diff --git a/src/smtml/smtlib.ml b/src/smtml/smtlib.ml index dedeb80f..4406d28d 100644 --- a/src/smtml/smtlib.ml +++ b/src/smtml/smtlib.ml @@ -362,7 +362,7 @@ module Statement = struct let name = match Symbol.name id with Simple name -> name | _ -> assert false in - Hashtbl.replace custom_sorts name (Expr.ty t); + Hashtbl.add custom_sorts name (Expr.ty t); Echo "" let datatypes ?loc:_ = assert false