From 5a542d158d3bde832e38b65ad5347299fbe7ee32 Mon Sep 17 00:00:00 2001 From: Pierre Goutagny Date: Thu, 17 Jun 2021 15:58:29 +0200 Subject: Simplify tunneling factorisation The recursive module definitions required unnecessarily long expicit signatures for little added legibility. --- backend/RTLTunnelingaux.ml | 65 ++++------- backend/Tunnelinglibs.ml | 276 ++++++++++++++++++++++++--------------------- 2 files changed, 173 insertions(+), 168 deletions(-) (limited to 'backend') diff --git a/backend/RTLTunnelingaux.ml b/backend/RTLTunnelingaux.ml index 5fe327d4..9333e357 100644 --- a/backend/RTLTunnelingaux.ml +++ b/backend/RTLTunnelingaux.ml @@ -26,8 +26,6 @@ open Maps open Camlcoq open Tunnelinglibs -let nopcounter = ref 0 - module LANG = struct type code_unit = RTL.instruction type funct = RTL.coq_function @@ -40,60 +38,44 @@ module OPT = struct let final_dump = false end -module rec T: sig - val get_node: cfg -> positive -> node - val set_branch: cfg -> positive -> node -> unit - val debug: ('a, out_channel, unit) format -> 'a - val string_of_labeli: ('a, node) Hashtbl.t -> 'a -> string - exception BugOnPC of int - val branch_target: RTL.coq_function -> (positive * Z.t) PTree.t - -end = Tunnelinglibs.Tunneling(LANG)(OPT)(FUNS) - -and FUNS: sig - val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list - - val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool +module Partial = Tunnelinglibs.Tunneling(LANG)(OPT) - val fn_code: LANG.funct -> LANG.code_unit PTree.t - val fn_entrypoint: LANG.funct -> positive - - val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit -end = struct +module FUNS = struct let build_simplified_cfg c acc pc i = match i with | Inop s -> - let ns = T.get_node c s in - T.set_branch c pc ns; + let ns = get_node c s in + set_branch c pc ns; incr nopcounter; acc | Icond (_, _, s1, s2, _) -> c.num_rems <- c.num_rems + 1; - let ns1 = T.get_node c s1 in - let ns2 = T.get_node c s2 in - let npc = T.get_node c pc in + let ns1 = get_node c s1 in + let ns2 = get_node c s2 in + let npc = get_node c pc in npc.inst <- COND(ns1, ns2); npc::acc | _ -> acc let print_code_unit c println (pc, i) = match i with - | Inop s -> (if println then T.debug "\n"); T.debug "%d:Inop %d %s\n" pc (P.to_int s) (T.string_of_labeli c.nodes pc); false - | Icond (_, _, s1, s2, _) -> (if println then T.debug "\n"); T.debug "%d:Icond (%d,%d) %s\n" pc (P.to_int s1) (P.to_int s2) (T.string_of_labeli c.nodes pc); false - | _ -> T.debug "%d " pc; true + | Inop s -> (if println then Partial.debug "\n"); + Partial.debug "%d:Inop %d %s\n" pc (P.to_int s) (string_of_labeli c.nodes pc); + false + | Icond (_, _, s1, s2, _) -> (if println then Partial.debug "\n"); + Partial.debug "%d:Icond (%d,%d) %s\n" pc (P.to_int s1) (P.to_int s2) (string_of_labeli c.nodes pc); + false + | _ -> Partial.debug "%d " pc; + true let fn_code f = f.fn_code let fn_entrypoint f = f.fn_entrypoint + (*************************************************************) (* Copy-paste of the extracted code of the verifier *) (* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *) - let get td pc = - match PTree.get pc td with - | Some p -> let (t0, d) = p in (t0, d) - | None -> (pc, Z.of_uint 0) - let check_code_unit td pc i = match PTree.get pc td with | Some p -> @@ -105,8 +87,8 @@ end = struct if peq tpc ts then if zlt ds dpc0 then () - else raise (T.BugOnPC (P.to_int pc)) - else raise (T.BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) | Icond (_, _, s1, s2, _) -> let (ts1, ds1) = get td s1 in let (ts2, ds2) = get td s2 in @@ -115,15 +97,16 @@ end = struct then if zlt ds1 dpc0 then if zlt ds2 dpc0 then () - else raise (T.BugOnPC (P.to_int pc)) - else raise (T.BugOnPC (P.to_int pc)) - else raise (T.BugOnPC (P.to_int pc)) - else raise (T.BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) + else raise (BugOnPC (P.to_int pc)) | _ -> - raise (T.BugOnPC (P.to_int pc)) end + raise (BugOnPC (P.to_int pc)) end | None -> () end +module T = Partial.T(FUNS) let branch_target = T.branch_target diff --git a/backend/Tunnelinglibs.ml b/backend/Tunnelinglibs.ml index e1e61d68..1bb35f7a 100644 --- a/backend/Tunnelinglibs.ml +++ b/backend/Tunnelinglibs.ml @@ -39,51 +39,24 @@ and node = { mutable tag: int } -type cfg_node = (int, node) Hashtbl.t - type positive = P.t type integer = Z.t (* type of the (simplified) CFG *) type cfg = { - nodes: cfg_node; + nodes: (int, node) Hashtbl.t; mutable rems: node list; (* remaining conditions that may become lbranch or not *) mutable num_rems: int; mutable iter_num: int (* number of iterations in elimination of conditions *) } -module Tunneling = functor - (LANG: sig - type code_unit (* the type of a node of the code cfg (an instruction or a bblock *) - type funct - end) - (OPT: sig - val langname: string - val limit_tunneling: int option (* for debugging: [Some x] limit the number of iterations *) - val debug_flag: bool ref - val final_dump: bool (* set to true to have a more verbose debugging *) - end) - (FUNS: sig - (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks *) - val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list - - val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool - - val fn_code: LANG.funct -> LANG.code_unit PTree.t - val fn_entrypoint: LANG.funct -> positive - - val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit - end) --> struct - -let debug fmt = - if !OPT.debug_flag then Printf.eprintf fmt - else Printf.ifprintf stderr fmt - exception BugOnPC of int -let lab_i (n: node): int = fst n.lab -let lab_p (n: node): P.t = snd n.lab +(* keeps track of the total number of nops seen, for debugging purposes *) +let nopcounter = ref 0 + +(* General functions that do not require language-specific context, and that + are used for building language-specific functions *) let rec target c n = (* inspired from the "find" of union-find algorithm *) match n.inst with @@ -126,41 +99,14 @@ let set_branch c p s = let n = { lab = (li,p); inst = BRANCH s; link = target c s; dist = 0; tag = 0 } in Hashtbl.add c.nodes li n +let get td pc = + match PTree.get pc td with + | Some p -> let (t0, d) = p in (t0, d) + | None -> (pc, Z.of_uint 0) +let lab_i (n: node): int = fst n.lab +let lab_p (n: node): P.t = snd n.lab -(* try to change a condition into a branch -[acc] is the current accumulator of conditions to consider in the next iteration of repeat_change_cond -*) -let try_change_cond c acc pc = - match pc.inst with - | COND(s1,s2) -> - let ts1 = target c s1 in - let ts2 = target c s2 in - if ts1 == ts2 then ( - pc.link <- ts1; - c.num_rems <- c.num_rems - 1; - acc - ) else - pc::acc - | _ -> raise (BugOnPC (lab_i pc)) (* COND expected *) - -(* repeat [try_change_cond] until no condition is changed into a branch *) -let rec repeat_change_cond c = - c.iter_num <- c.iter_num + 1; - debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems); - let old = c.num_rems in - c.rems <- List.fold_left (try_change_cond c) [] c.rems; - let curr = c.num_rems in - let continue = - match OPT.limit_tunneling with - | Some n -> curr < old && c.iter_num < n - | None -> curr < old - in - if continue - then repeat_change_cond c - - -(* compute the final distance of each nop nodes to its target *) let undef_dist = -1 let self_dist = undef_dist-1 let rec dist n = @@ -176,29 +122,6 @@ let rec dist n = ) else if n.dist=self_dist then raise (BugOnPC (lab_i n)) else n.dist -let final_export f c = - let count = ref 0 in - let filter_nops_init_dist _ n acc = - let tn = target c n in - if tn == n - then ( - n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *) - acc - ) else ( - n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *) - count := !count+1; - n::acc - ) - in - let nops = Hashtbl.fold filter_nops_init_dist c.nodes [] in - let res = List.fold_left (fun acc n -> PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) acc) PTree.empty nops in - debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n" - OPT.langname !count; - res - -(*********************************************) -(*** START: printing and debugging functions *) - let string_of_labeli nodes ipc = try let pc = Hashtbl.find nodes ipc in @@ -208,43 +131,142 @@ let string_of_labeli nodes ipc = with Not_found -> "" -let print_cfg (f: LANG.funct) c = - let a = Array.of_list (PTree.fold (fun acc pc cu -> (P.to_int pc,cu)::acc) (FUNS.fn_code f) []) in - Array.fast_sort (fun (i1,_) (i2,_) -> i2 - i1) a; - let ep = P.to_int (FUNS.fn_entrypoint f) in - debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep); - let println = Array.fold_left (FUNS.print_code_unit c) false a in - (if println then debug "\n");debug "remaining cond:"; - List.iter (fun n -> debug "%d " (lab_i n)) c.rems; - debug "\n" - -(*************************************************************) -(* Copy-paste of the extracted code of the verifier *) -(* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *) - -(** val check_code : coq_UF -> code -> unit res **) - -let check_code td c = - PTree.fold (fun _ pc cu -> FUNS.check_code_unit td pc cu) c (()) - -(*** END: copy-paste & debugging functions *******) - -let branch_target f = - debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname; - if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n"; - let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in - c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) []; - repeat_change_cond c; - let res = final_export f c in - if !OPT.debug_flag then ( - try - check_code res (FUNS.fn_code f); - if OPT.final_dump then print_cfg f c; - with e -> ( - print_cfg f c; - check_code res (FUNS.fn_code f) - ) - ); - res +(* + * When given the necessary types and options as context, and then some + * language-specific functions that cannot be factorised between LTL and RTL, the + * `Tunneling` functor returns a module containing the corresponding + * `branch_target` function. + *) + +module Tunneling = functor + (* Language-specific types *) + (LANG: sig + type code_unit (* the type of a node of the code cfg (an instruction or a bblock *) + type funct (* type of internal functions *) + end) + (* Compilation options for debugging *) + (OPT: sig + val langname: string + val limit_tunneling: int option (* for debugging: [Some x] limit the number of iterations *) + val debug_flag: bool ref + val final_dump: bool (* set to true to have a more verbose debugging *) + end) + -> struct + + (* The `debug` function uses values from `OPT`, and is used in functions passed to `F` + so it must be defined between the two *) + let debug fmt = + if !OPT.debug_flag then Printf.eprintf fmt + else Printf.ifprintf stderr fmt + + module T + (* Language-specific functions *) + (FUNS: sig + (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks *) + val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list + val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool + val fn_code: LANG.funct -> LANG.code_unit PTree.t + val fn_entrypoint: LANG.funct -> positive + val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit + end) + (* only export what's needed *) + : sig val branch_target: LANG.funct -> (positive * integer) PTree.t end + = struct + + (* try to change a condition into a branch [acc] is the current accumulator of + conditions to consider in the next iteration of repeat_change_cond *) + let try_change_cond c acc pc = + match pc.inst with + | COND(s1,s2) -> + let ts1 = target c s1 in + let ts2 = target c s2 in + if ts1 == ts2 then ( + pc.link <- ts1; + c.num_rems <- c.num_rems - 1; + acc + ) else + pc::acc + | _ -> raise (BugOnPC (lab_i pc)) (* COND expected *) + + (* repeat [try_change_cond] until no condition is changed into a branch *) + let rec repeat_change_cond c = + c.iter_num <- c.iter_num + 1; + debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems); + let old = c.num_rems in + c.rems <- List.fold_left (try_change_cond c) [] c.rems; + let curr = c.num_rems in + let continue = + match OPT.limit_tunneling with + | Some n -> curr < old && c.iter_num < n + | None -> curr < old + in + if continue + then repeat_change_cond c + + + (*********************************************) + (*** START: printing and debugging functions *) + + let print_cfg (f: LANG.funct) c = + let a = Array.of_list (PTree.fold (fun acc pc cu -> (P.to_int pc,cu)::acc) (FUNS.fn_code f) []) in + Array.fast_sort (fun (i1,_) (i2,_) -> i2 - i1) a; + let ep = P.to_int (FUNS.fn_entrypoint f) in + debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep); + let println = Array.fold_left (FUNS.print_code_unit c) false a in + (if println then debug "\n");debug "remaining cond:"; + List.iter (fun n -> debug "%d " (lab_i n)) c.rems; + debug "\n" + + + (*************************************************************) + (* Copy-paste of the extracted code of the verifier *) + (* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *) + + (** val check_code : coq_UF -> code -> unit res **) + + let check_code td c = + PTree.fold (fun _ pc cu -> FUNS.check_code_unit td pc cu) c (()) + + (*** END: copy-paste & debugging functions *******) + + (* compute the final distance of each nop nodes to its target *) + let final_export f c = + let count = ref 0 in + let filter_nops_init_dist _ n acc = + let tn = target c n in + if tn == n + then ( + n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *) + acc + ) else ( + n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *) + count := !count+1; + n::acc + ) + in + let nops = Hashtbl.fold filter_nops_init_dist c.nodes [] in + let res = List.fold_left (fun acc n -> PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) acc) PTree.empty nops in + debug "* %sTunneling.branch_target: initial number of nops = %d\n" OPT.langname !nopcounter; + debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n" OPT.langname !count; + res + + let branch_target f = + debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname; + if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n"; + let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in + c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) []; + repeat_change_cond c; + let res = final_export f c in + if !OPT.debug_flag then ( + try + check_code res (FUNS.fn_code f); + if OPT.final_dump then print_cfg f c; + with e -> ( + print_cfg f c; + check_code res (FUNS.fn_code f) + ) + ); + res + end end -- cgit