diff options
Diffstat (limited to 'src/lfsc/ast.ml')
-rw-r--r-- | src/lfsc/ast.ml | 961 |
1 files changed, 961 insertions, 0 deletions
diff --git a/src/lfsc/ast.ml b/src/lfsc/ast.ml new file mode 100644 index 0000000..29a4afc --- /dev/null +++ b/src/lfsc/ast.ml @@ -0,0 +1,961 @@ +(**************************************************************************) +(* *) +(* SMTCoq *) +(* Copyright (C) 2011 - 2019 *) +(* *) +(* See file "AUTHORS" for the list of authors *) +(* *) +(* This file is distributed under the terms of the CeCILL-C licence *) +(* *) +(**************************************************************************) + + +open Format + +exception CVC4Sat + +let debug = + (* true *) + false + +(********************************) +(* Type definitions for the AST *) +(********************************) + +type mpz = Big_int.big_int +type mpq = Num.num + + +type name = Name of Hstring.t | S_Hole of int +type symbol = { sname : name; stype : term } + +and dterm = + | Type + | Kind + | Mpz + | Mpq + | Const of symbol + | App of term * term list + | Int of mpz + | Rat of mpq + | Pi of symbol * term + | Lambda of symbol * term + | Hole of int + | Ptr of term + | SideCond of Hstring.t * term list * term * term + +and term = { mutable value: dterm; ttype: term } +(* TODO: remove type annotations in terms *) + +type command = + | Check of term + | Define of Hstring.t * term + | Declare of Hstring.t * term + + +type proof = command list + + +module H = struct + let holds = Hstring.make "holds" + let th_holds = Hstring.make "th_holds" + let mp_add = Hstring.make "mp_add" + let mp_mul = Hstring.make "mp_mul" + let uminus = Hstring.make "~" + let eq = Hstring.make "=" +end + + +let is_rule t = + match t.ttype.value with + | App ({value=Const{sname=Name n}}, _) -> n == H.holds || n == H.th_holds + | _ -> false + + +let rec deref t = match t.value with + | Ptr t -> deref t + | _ -> t + + +let value t = (deref t).value + + +let ttype t = deref (deref t).ttype + + +let rec name c = match value c with + | Const {sname=Name n} -> Some n + | _ -> None + + +let rec app_name r = match value r with + | App ({value=Const{sname=Name n}}, args) -> Some (n, args) + | _ -> None + + +(*******************) +(* Pretty printing *) +(*******************) + +let address_of (x:'a) : nativeint = + if Obj.is_block (Obj.repr x) then + Nativeint.shift_left (Nativeint.of_int (Obj.magic x)) 1 (* magic *) + else + invalid_arg "Can only find address of boxed values." + +let rec print_symbol fmt { sname } = match sname with + | Name n -> Hstring.print fmt n + | S_Hole i -> fprintf fmt "_s%d" i + +and print_tval pty fmt t = match t.value with + | Type -> fprintf fmt "type" + | Kind -> fprintf fmt "kind" + | Mpz -> fprintf fmt "mpz" + | Mpq -> fprintf fmt "mpz" + | Const s -> print_symbol fmt s + | App (f, args) when pty && is_rule t -> + let color = (Hashtbl.hash f.value mod 216) + 16 in + let op, cl = sprintf "\x1b[38;5;%dm" color, "\x1b[0m" in + fprintf fmt "@[@<0>%s(%a@<0>%s%a@<0>%s)@,@<0>%s@]" + op + (print_tval false) f + cl + (fun fmt -> List.iter (fprintf fmt "@ %a" (print_term pty))) args + op cl + | App (f, args) -> + fprintf fmt "@[(%a%a)@,@]" + (print_tval false) f + (fun fmt -> List.iter (fprintf fmt "@ %a" (print_term pty))) args + + | Int n -> pp_print_string fmt (Big_int.string_of_big_int n) + | Rat q -> pp_print_string fmt (Num.string_of_num q) + | Pi (s, t) -> + fprintf fmt "(! %a@ %a@ %a)@," + print_symbol s + (print_term false) s.stype + (print_term pty) t + | Lambda (s, t) -> + fprintf fmt "(%% %a@ %a@ %a)@," print_symbol s (print_term pty) s.stype + (print_term pty) t + | Hole i -> + if false && debug then + fprintf fmt "_%d[%nx]" i (address_of t) + else + fprintf fmt "_%d" i + + | Ptr t when (* true || *) debug -> fprintf fmt "*%a" (print_term pty) t + + | Ptr t -> print_term pty fmt t + + | SideCond (name, args, expected, t) -> + fprintf fmt "(! _ (^ (%a%a)@ %a)@ %a)" + Hstring.print name + (fun fmt -> List.iter (fprintf fmt "@ %a" (print_term pty))) args + (print_term pty) expected + (print_term pty) t + + +and print_term pty fmt t = match t with + | {value = Type | Kind | Ptr _ | Const _} + | {ttype = {value = Type | Kind | Const _ | Ptr _}} -> + print_tval pty fmt t + | _ when t.ttype == t -> + print_tval pty fmt t + (* | _ when pty -> *) + (* fprintf fmt "[@[%a:%a@]]" (print_tval pty) t (print_term pty) t.ttype *) + | _ when pty && is_rule t -> + let op, cl = "\x1b[30m", "\x1b[0m" in + fprintf fmt "@\n@[@<0>%s(: %a@<0>%s@\n%a@<0>%s)@<0>%s@,@]" + op (print_term false) t.ttype cl (print_tval pty) t op cl + (* | _ when pty -> *) + (* fprintf fmt "@[(:@ %a@ %a)@]" *) + (* (print_term false) t.ttype (print_tval pty) t *) + (* | _ when pty -> *) + (* fprintf fmt "@[%a\x1b[30m:%a\x1b[0m)@]" *) + (* (print_tval pty) t (print_term false) t.ttype *) + | _ -> + fprintf fmt "@[%a@]" (print_tval pty) t + + +let print_term_type = print_term true +let print_term = print_term false + +let print_command fmt = function + | Check t -> + fprintf fmt "(check@ (:@\n@\n %a@ @\n@\n%a))" + print_term t.ttype print_term_type t + | Define (s, t) -> + fprintf fmt "(define %a@ %a)" Hstring.print s print_term t + | Declare (s, t) -> + fprintf fmt "(declare %a@ %a)" Hstring.print s print_term t + +let print_proof fmt = + List.iter (fprintf fmt "@[<1>%a@]@\n@." print_command) + + + +let compare_symbol s1 s2 = match s1.sname, s2.sname with + | Name n1, Name n2 -> Hstring.compare n1 n2 + | Name _, _ -> -1 + | _, Name _ -> 1 + | S_Hole i1, S_Hole i2 -> Pervasives.compare i1 i2 + + +let rec compare_term ?(mod_eq=false) t1 t2 = match t1.value, t2.value with + | Ptr t1, _ -> compare_term ~mod_eq t1 t2 + | _, Ptr t2 -> compare_term ~mod_eq t1 t2 + | Type, Type | Kind, Kind | Mpz, Mpz | Mpq, Mpz -> 0 + | Type, _ -> -1 | _, Type -> 1 + | Kind, _ -> -1 | _, Kind -> 1 + | Mpz, _ -> -1 | _, Mpz -> 1 + | Mpq, _ -> -1 | _, Mpq -> 1 + | Int n1, Int n2 -> Big_int.compare_big_int n1 n2 + | Int _, _ -> -1 | _, Int _ -> 1 + | Rat q1, Rat q2 -> Num.compare_num q1 q2 + | Rat _, _ -> -1 | _, Rat _ -> 1 + | Const s1, Const s2 -> compare_symbol s1 s2 + | Const _, _ -> -1 | _, Const _ -> 1 + | App ({value=Const{sname=Name n1}}, [ty1; a1; b1]), + App ({value=Const{sname=Name n2}}, [ty2; a2; b2]) + when n1 == H.eq && n2 == H.eq && mod_eq -> + let c = compare_term ~mod_eq ty1 ty2 in + if c <> 0 then c + else + let ca1a2 = compare_term ~mod_eq a1 a2 in + let ca1b2 = compare_term ~mod_eq a1 b2 in + let cb1b2 = compare_term ~mod_eq b1 b2 in + let cb1a2 = compare_term ~mod_eq b1 a2 in + if ca1a2 = 0 && cb1b2 = 0 then 0 + else if ca1b2 = 0 && cb1a2 = 0 then 0 + else if ca1a2 <> 0 then ca1a2 else cb1b2 + | App (f1, l1), App (f2, l2) -> + let c = compare_term ~mod_eq f1 f2 in + if c <> 0 then c else + compare_term_list ~mod_eq l1 l2 + | App _, _ -> -1 | _, App _ -> 1 + + | Pi (s1, t1), Pi (s2, t2) -> + let c = compare_symbol s1 s2 in + if c <> 0 then c + else compare_term ~mod_eq t1 t2 + | Pi _, _ -> -1 | _, Pi _ -> 1 + + | Lambda (s1, t1), Lambda (s2, t2) -> + let c = compare_symbol s1 s2 in + if c <> 0 then c + else compare_term ~mod_eq t1 t2 + | Lambda _, _ -> -1 | _, Lambda _ -> 1 + + (* ignore side conditions *) + | SideCond (_, _, _, t), _ -> compare_term ~mod_eq t t2 + | _, SideCond (_, _, _, t) -> compare_term ~mod_eq t1 t + + | Hole i1, Hole i2 -> Pervasives.compare i1 i2 + + +and compare_term_list ?(mod_eq=false) l1 l2 = match l1, l2 with + | [], [] -> 0 + | [], _ -> -1 + | _, [] -> 1 + | t1 :: r1, t2 :: r2 -> + let c = compare_term ~mod_eq t1 t2 in + if c <> 0 then c + else compare_term_list ~mod_eq r1 r2 + + +let rec hash_term t = match t.value with + | Ptr t -> hash_term t + | v -> Hashtbl.hash_param 100 500 v + + +module Term = struct + type t = term + let compare = compare_term ~mod_eq:false + let equal x y = compare_term x y = 0 + let hash t = Hashtbl.hash_param 10 100 t.value (* hash_term *) + (* let hasht = Hashtbl.hash *) + (* let rec hash = *) + (* let cpt = ref 0 in *) + (* fun hh t -> *) + (* incr cpt; *) + (* if !cpt > 10 then hh else *) + (* hh + *) + (* let v = t.value in *) + (* match v with *) + (* | Hole _ | Type | Kind | Mpz | Mpq | Int _ | Rat _ | Const _ -> hasht v *) + (* | SideCond (_, args, exp, t) -> *) + (* List.fold_left (fun acc t -> hash hh t + 31*acc) (hash hh t) args *) + (* | App (f, args) -> *) + (* List.fold_left (fun acc t -> hash hh t + 31*acc) (hash hh f) args *) + (* | Pi (s, x) -> ((Hashtbl.hash s) + 31*(hash hh x)) * 7 *) + (* | Lambda (s, x) -> ((Hashtbl.hash s) + 31*(hash hh x)) * 9 *) + (* | Ptr t' -> 0 *) + (* (\* t.value <- t'.value; *\) *) + (* (\* hash hh (deref t') *\) *) + (* let hash = hash 0 *) +end + + + + +let rec holes_address acc t = match t.value with + | Hole i -> (i, t) :: acc + | Type | Kind | Mpz | Mpq | Int _ | Rat _ -> acc + | SideCond (name, args, exp, t) -> acc + | Const _ -> acc + | App (f, args) -> + List.fold_left holes_address acc args + | Pi (s, x) -> holes_address acc x + | Lambda (s, x) -> holes_address acc x + | Ptr t -> holes_address acc t + +let holes_address = holes_address [] + + +let check_holes_integrity where h1 h2 = + List.iter (fun (i, a) -> + List.iter (fun (j, b) -> + if j = i && a != b then + ( + eprintf "\n%s: Hole _%d was at %nx, now at %nx\n@." where i + (address_of a) (address_of b); + (* eprintf "\n%s: Hole _%d has changed\n@." where i; *) + assert false) + ) h2 + ) h1 + +let check_term_integrity where t = + let h = holes_address t in + check_holes_integrity (where ^ "term has != _") h h + + + +let eq_name s1 s2 = match s1, s2 with + | S_Hole i1, S_Hole i2 -> i1 == i2 + | Name n1, Name n2 -> n1 == n2 + | _ -> false + +module HN = Hashtbl.Make (struct + type t = name + let equal = eq_name + let hash = function + | S_Hole i -> i * 7 + | Name n -> Hstring.hash n * 9 + end) + +let symbols = HN.create 21 +let register_symbol s = HN.add symbols s.sname s.stype +let remove_symbol s = HN.remove symbols s.sname + +let definitions = HN.create 21 +let add_definition n t = HN.add definitions n t +let remove_definition n = HN.remove definitions n + + +exception TypingError of term * term + + +(**************************) +(* Predefined terms/types *) +(**************************) + + +let rec kind = { value = Kind; ttype = kind } + +let lfsc_type = { value = Type; ttype = kind } + +let mpz = { value = Mpz; ttype = lfsc_type } + +let mpq = { value = Mpq; ttype = lfsc_type } + +let mk_mpz n = { value = Int n; ttype = mpz } + +let mpz_of_int n = { value = Int (Big_int.big_int_of_int n); ttype = mpz } + +let mk_mpq n = { value = Rat n; ttype = mpq } + + +let mk_symbol s stype = + { sname = Name (Hstring.make s) ; stype } + (* { sname = Name (String.concat "." (List.rev (n :: scope))) ; stype } *) + +let mk_symbol_hole = + let cpt = ref 0 in + fun stype -> + incr cpt; + { sname = S_Hole !cpt; stype } + +let is_hole = function { value = Hole _ } -> true | _ -> false + +let is_hole_symbol = function { sname = S_Hole _ } -> true | _ -> false + +let mk_hole = + let cpt = ref 0 in + fun ttype -> + incr cpt; + { value = Hole !cpt; ttype } + +(* let mk_rec_hole () = *) +(* let rec h = { value = Hole !cpt; ttype = h } in *) +(* h *) + +let mk_hole_hole () = + mk_hole (mk_hole lfsc_type) + + +(*****************************) +(* Side conditions callbacks *) +(*****************************) + +let callbacks_table = Hstring.H.create 7 + + +let mp_add x y = + (* eprintf "mp_add %a %a@." print_term x print_term y; *) + match value x, value y with + | Int xi, Int yi -> mk_mpz (Big_int.add_big_int xi yi) + | _ -> assert false + +let mp_mul x y = match value x, value y with + | Int xi, Int yi -> mk_mpz (Big_int.mult_big_int xi yi) + | _ -> assert false + +let uminus x = match value x with + | Int xi -> mk_mpz (Big_int.minus_big_int xi) + | _ -> assert false + + +let rec eval_arg x = match app_name x with + | Some (n, [x]) when n == H.uminus -> uminus (eval_arg x) + | Some (n, [x; y]) when n == H.mp_add -> mp_add (eval_arg x) (eval_arg y) + | Some (n, [x; y]) when n == H.mp_mul -> mp_mul (eval_arg x) (eval_arg y) + | _ -> x + + +let callback name l = + try + let f = Hstring.H.find callbacks_table name in + (* eprintf "apply %s ... @." name; *) + let l = List.map eval_arg l in + f l + with Not_found -> + failwith ("No side condition for " ^ Hstring.view name) + + + +(* type sc_check = String * term list * term *) + + +(* type sc_tree = *) +(* | SCEmpty *) +(* (\* | SCLeaf of sc_check *\) *) +(* | SCBranches of sc_check * sc_tree list *) + + +(* let sct = ref (SCEmpty) *) + + +let sc_to_check = ref [] + + + +(**********************************) +(* Smart constructors for the AST *) +(**********************************) + +module MSym = Map.Make (struct + type t = symbol + let compare = compare_symbol + end) + + +let empty_subst = MSym.empty + +let fresh_alpha = + let cpt = ref 0 in + fun ty -> incr cpt; + mk_symbol ("'a"^string_of_int !cpt) ty + + +let get_t ?(gen=true) sigma s = + try + let x = MSym.find s sigma in + if not gen && is_hole x then raise Not_found; + x + with Not_found -> try + HN.find definitions s.sname + with Not_found -> + { value = Const s; ttype = s.stype } + + +type substres = T of term | V of dterm | Same + + +let apply_subst_sym sigma s = + try + let x = MSym.find s sigma in + T x + with Not_found -> Same + (* try *) + (* T (Hashtbl.find definitions s) *) + (* with Not_found -> Same *) + + +let print_subst fmt sigma = + fprintf fmt "@[<v 1>["; + MSym.iter (fun s t -> + fprintf fmt "@ %a -> %a;" print_symbol s print_term t) sigma; + fprintf fmt " ]@]" + + +let rec apply_subst_val sigma tval = match tval with + | Type | Kind | Mpz | Mpq | Int _ | Rat _ | Hole _ -> Same + + (* | Ptr t -> *) + (* V (Ptr (apply_subst sigma t)) *) + (* | Ptr t -> apply_subst_val sigma t.value *) + + | Ptr t -> + T (apply_subst sigma t) + + | Const s when is_hole_symbol s -> Same (* raise Exit *) + | Const s -> apply_subst_sym sigma s + | App (f, args) -> + let nf = apply_subst sigma f in + let nargs = List.rev_map (apply_subst sigma) args |> List.rev in + if nf == f && List.for_all2 (==) nargs args then (* V tval *) Same + else + V (App(nf, nargs)) + + | Pi (s, x) -> + let s = { s with stype = apply_subst sigma s.stype } in + let sigma = MSym.remove s sigma in + let newx = apply_subst sigma x in + if x == newx then (* V tval *) Same + else + V (Pi (s, newx)) + + | Lambda (s, x) -> + let s = { s with stype = apply_subst sigma s.stype } in + let sigma = MSym.remove s sigma in + let newx = apply_subst sigma x in + if x == newx then (* V tval *) Same + else + V (Lambda (s, newx)) + + | SideCond (name, args, exp, t) -> + let nt = apply_subst sigma t in + let nexp = apply_subst sigma exp in + let nargs = List.rev_map (apply_subst sigma) args |> List.rev in + if nt == t && nexp == exp && List.for_all2 (==) nargs args then (* V tval *) Same + else + V (SideCond (name, nargs, nexp, nt)) + + + +and apply_subst sigma t = + match apply_subst_val sigma t.value with + | Same -> t + | T t -> t + | V value -> + let ttype = apply_subst sigma t.ttype in + if value == t.value && ttype == t.ttype then t + else { value; ttype } + + + +let get_real t = apply_subst MSym.empty t + + +let rec flatten_term_value t = match t.value with + | Hole _ | Type | Kind | Mpz | Mpq | Int _ | Rat _ -> () + | SideCond (_, args, exp, t) -> + List.iter flatten_term args; + flatten_term exp; + flatten_term t + | Const s -> flatten_term s.stype + | App (f, args) -> + flatten_term f; + List.iter flatten_term args + | Pi (s, x) | Lambda (s, x) -> + flatten_term s.stype; + flatten_term x + | Ptr t' -> + t.value <- (deref t').value + (* flatten_term t *) + + +and flatten_term t = + flatten_term_value t + (* ; *) + (* match t.value with *) + (* | Type | Kind -> () *) + (* | _ -> flatten_term t.ttype *) + + +let rec has_ptr_val t = match t.value with + | Hole _ | Type | Kind | Mpz | Mpq | Int _ | Rat _ -> false + | SideCond (_, args, exp, t) -> + List.exists has_ptr args || has_ptr exp || has_ptr t + | Const s -> has_ptr s.stype + | App (f, args) -> has_ptr f || List.exists has_ptr args + | Pi (s, x) | Lambda (s, x) -> has_ptr s.stype || has_ptr x + | Ptr _ -> true + +and has_ptr t = + has_ptr_val t || + match t.value with + | Type | Kind -> false + | _ -> has_ptr t.ttype + + +let add_subst x v sigma = MSym.add x v sigma + (* let sigma = List.rev_map (fun (y, w) -> y, apply_subst [x,v] w) sigma |> List.rev in *) + (* (x, apply_subst sigma v) :: sigma *) + + + +let rec occur_check subt t = + compare_term t subt = 0 + || + match t.value with + | Type | Kind | Mpz | Mpq | Int _ | Rat _ | Hole _ | Const _ -> false + + | Ptr t -> occur_check subt t + + | App (f, args) -> + occur_check subt f || + List.exists (occur_check subt) args + + | Pi (s, x) -> occur_check subt x + + | Lambda (s, x) -> occur_check subt x + + | SideCond (name, args, exp, t) -> + occur_check subt exp || + occur_check subt t || + List.exists (occur_check subt) args + + + + +let rec fill_hole sigma h t = + match h.value with + | Hole _ -> + if debug then + eprintf ">>>>> Fill hole @[%a@] with @[%a@]@." + print_term h print_term t; + let t' = apply_subst sigma t in + (* h.value <- t'.value; (\* KLUDGE *\) *) + if not (occur_check h t') then h.value <- Ptr (deref t'); + if debug then + eprintf ">>>>>>>>> @[%a@]@." print_term_type h; + fill_hole sigma h.ttype t'.ttype; + (* (try compat_with sigma t'.ttype h.ttype with _ -> ()); *) + | _ -> () + + + + +(* Raise TypingError if t2 is not compatible with t1 *) +(* apsub is false if we want to prevent application of substitutions *) +and compat_with1 ?(apsub=true) sigma t1 t2 = + if debug then ( + eprintf "compat_with(%b): @[<hov>%a@] and @[<hov>%a@]@." + apsub print_term t1 print_term t2; + eprintf " with sigma = %a@." print_subst sigma + ); + + match t1.value, t2.value with + | Type, Type + | Kind, Kind + | Mpz, Mpz + | Mpq, Mpz -> () + + | Int z1, Int z2 -> if not (Big_int.eq_big_int z1 z2) then raise Exit + | Rat q1, Rat q2 -> if not (Num.eq_num q1 q2) then raise Exit + + | Ptr t, _ -> compat_with1 ~apsub sigma t t2 + | _, Ptr t -> compat_with1 ~apsub sigma t1 t + + | Const s1, Const s2 -> + if apsub then + let a2 = get_t sigma s2 in + let a1 = get_t ~gen:(not (is_hole a2)) sigma s1 in + compat_with1 sigma ~apsub:false a1 a2 + else + if not (eq_name s1.sname s2.sname) then raise Exit + + | App (f1, args1), App (f2, args2) -> + compat_with1 sigma f1 f2; + List.iter2 (compat_with sigma) args1 args2 + + | Pi (s1, t1), Pi (s2, t2) -> + compat_with1 sigma s1.stype s2.stype; + let a = s2 in + let ta = { value = Const a; ttype = a.stype } in + compat_with1 (add_subst s1 ta sigma) t1 t2; + + | Lambda (s1, t1), Lambda (s2, t2) -> + compat_with sigma s1.stype s2.stype; + let a = s2 in + let ta = { value = Const a; ttype = a.stype } in + compat_with1 (add_subst s1 ta sigma) t1 t2; + + + | SideCond (name, args, expected, t1), _ -> + check_side_condition name + (List.rev_map (apply_subst sigma) args |> List.rev) + (apply_subst sigma expected); + compat_with1 sigma t1 t2 + + (* ignore side conditions on the right *) + | _, SideCond (name, args, expected, t2) -> + compat_with1 sigma t1 t2 + + | Hole i, Hole j when i = j -> () + (* failwith ("Cannot infer type of holes, too many holes.") *) + + | _, Hole _ -> fill_hole sigma t2 t1 + | Hole _, _ -> fill_hole sigma t1 t2 + + + | Const s, _ -> + if apsub then + let a = get_t sigma s in + compat_with1 sigma ~apsub:false a t2 + else + raise Exit + + | _, Const s -> + if apsub then + let a = get_t sigma s in + compat_with1 sigma ~apsub:false a t1 + else + raise Exit + + | _ -> raise Exit + + +and compat_with sigma t1 t2 = + try compat_with1 sigma t1 t2 + with Exit -> + raise (TypingError (apply_subst sigma t1, apply_subst sigma t2)) + + + +and term_equal t1 t2 = + try + compat_with empty_subst t1 t2; + true + with + | TypingError _ | Failure _ -> false + + + +and check_side_condition name l expected = + if debug then + eprintf "Adding side condition : (%a%a) =?= %a@." + Hstring.print name + (fun fmt -> List.iter (fprintf fmt "@ %a" print_term)) l + print_term expected; + (* if not (term_equal (callback name l) expected) then *) + (* failwith ("Side condition " ^ name ^ " failed"); *) + sc_to_check := (name, l, expected) :: !sc_to_check + + + +let rec ty_of_app sigma ty args = match ty.value, args with + | Pi (s, t), a :: rargs -> + let sigma = add_subst s a sigma in + compat_with sigma s.stype a.ttype; + ty_of_app sigma t rargs + + | SideCond (name, scargs, expected, t), args -> + check_side_condition name + (List.rev_map (apply_subst sigma) scargs |> List.rev) + (apply_subst sigma expected); + ty_of_app sigma t args + + | _, [] -> apply_subst sigma ty + | _ -> failwith ("Type of function not a pi-type.") + + +let mk_const x = + if debug then eprintf "mk_const %s@." x; + try + let stype = HN.find symbols (Name (Hstring.make x)) in + let s = mk_symbol x stype in + try + HN.find definitions s.sname + with Not_found -> { value = Const s; ttype = stype } + with Not_found -> failwith ("Symbol " ^ x ^ " is not declared.") + + +let symbol_to_const s = { value = Const s; ttype = s.stype } + + +let rec mk_app ?(lookup=true) sigma f args = + if debug then + eprintf "mk_App : %a@." (print_tval false) + { value = App (f, args); ttype = lfsc_type } ; + + match f.value, args with + | Lambda (x, r), a :: rargs -> + let sigma = MSym.remove x sigma in + mk_app (add_subst x a sigma) r rargs + + (* | Const {sname = Name "mp_add"}, [x; y] -> mp_add x y *) + + (* | Const {sname = Name "mp_mul"}, [x; y] -> mp_mul x y *) + + | Const s, _ when lookup -> + (* find the definition if it has one *) + let f = get_t sigma s in + mk_app ~lookup:false sigma f args + + | x, [] -> + (* Delayed beta-reduction *) + apply_subst sigma f + + | _ -> + (* TODO: check if empty_subst or sigma *) + { value = App (f, args); ttype = ty_of_app empty_subst f.ttype args } + + +let mk_app = mk_app empty_subst + + +let rec hole_nbs acc t = match value t with + | Hole nb -> nb :: acc + | App (f, args) -> List.fold_left hole_nbs (hole_nbs acc f) args + | Pi (s, x) | Lambda (s, x) -> hole_nbs acc x + | Ptr t -> hole_nbs acc t + | _ -> acc + + +let rec min_hole acc t = match value t with + | Hole nb -> + (match acc with Some n when nb < n -> Some nb | None -> Some nb | _ -> acc) + | App (f, args) -> List.fold_left min_hole (min_hole acc f) args + | Pi (s, x) | Lambda (s, x) -> min_hole acc x + | Ptr t -> min_hole acc t + | _ -> acc + + +let compare_int_opt m1 m2 = match m1, m2 with + | None, None -> 0 + | Some _, None -> -1 + | None, Some _ -> 1 + | Some n1, Some n2 -> compare n1 n2 + + +let compare_sc_checks (_, args1, exp1) (_, args2, exp2) = + let el1 = hole_nbs [] exp1 in + let el2 = hole_nbs [] exp2 in + + let al1 = List.fold_left hole_nbs [] args1 in + let al2 = List.fold_left hole_nbs [] args2 in + + if List.exists (fun n -> List.mem n al1) el2 then 1 + else if List.exists (fun n -> List.mem n al2) el1 then -1 + else if el1 = [] then 1 + else if el2 = [] then -1 + else compare el1 el2 + + +let sort_sc_checks l = List.fast_sort compare_sc_checks l + + +let run_side_conditions () = + (* List.iter (fun (name, l, expected) -> *) + (* eprintf "\nSorted side condition : (%s%a) =?= %a@." *) + (* name *) + (* (fun fmt -> List.iter (fprintf fmt "@ %a" print_term)) l *) + (* print_term expected; *) + (* ) (List.flatten !all_scs |> sort_sc_checks); *) + + List.iter (fun (name, l, expected) -> + let res = callback name l in + if not (term_equal res expected) then + failwith (asprintf "Side condition %a failed: Got %a, expected %a" + Hstring.print name print_term res print_term expected); + ) (sort_sc_checks !sc_to_check); + sc_to_check := []; + () + + +let mk_pi s t = + (* let s = if is_hole_symbol s then fresh_alpha s.stype else s in *) + { value = Pi (s, t); ttype = lfsc_type } + +let mk_lambda s t = + (* sc_to_check := List.rev !sc_to_check; *) + (* run_side_conditions (); *) + (* let s = if is_hole_symbol s then fresh_alpha s.stype else s in *) + { value = Lambda (s, t); + ttype = mk_pi s t.ttype } + + +let mk_ascr ty t = + if debug then + eprintf "\nMK ASCR:: should have type %a, has type %a\n@." + print_term ty print_term t.ttype; + compat_with empty_subst ty t.ttype; t + (* { t with ttype = ty } *) + + +let add_sc name args expected t = + { value = SideCond (Hstring.make name, args, expected, t); + ttype = t.ttype } + + +let mk_declare n ty = + let s = mk_symbol n ty in + register_symbol s + +let mk_define n t = + let s = mk_symbol n t.ttype in + register_symbol s; + add_definition s.sname t + + + +let mk_check t = run_side_conditions () + + +let clear_sc () = sc_to_check := [] + + + +let rec hash_term_mod_eq p = match p.value with + | App ({value=Const{sname=Name n}} as f, [ty; a; b]) + when n == H.eq && + compare_term ~mod_eq:true a b > 0 -> + Term.hash (mk_app f [ty; b; a]) + | App (f, args) -> + List.fold_left + (fun acc t -> 7*(acc + hash_term_mod_eq f)) 1 (f:: args) + | Pi (s, x) -> + (Hashtbl.hash_param 100 500 s + hash_term_mod_eq x) * 11 + | Lambda (s, x) -> + (Hashtbl.hash_param 100 500 s + hash_term_mod_eq x) * 13 + | _ -> Hashtbl.hash_param 100 500 p + + +module Term_modeq = struct + type t = term + let compare = compare_term ~mod_eq:true + let equal x y = compare_term ~mod_eq:true x y = 0 + let hash t = + (* eprintf "HASH: %a@." print_term t; *) + hash_term_mod_eq t +end + + +(* + Local Variables: + compile-command: "make" + indent-tabs-mode: nil + End: +*) |