aboutsummaryrefslogtreecommitdiffstats
path: root/backend
diff options
context:
space:
mode:
authorPierre Goutagny <pierre.goutagny@ens-lyon.fr>2021-06-17 15:58:29 +0200
committerPierre Goutagny <pierre.goutagny@ens-lyon.fr>2021-06-17 15:58:29 +0200
commit5a542d158d3bde832e38b65ad5347299fbe7ee32 (patch)
treecd0fc578fe97cadbc3e3e5cb397bed3b84a22621 /backend
parent67bc93934e939e57c80ade4c37aaba1535222fa2 (diff)
downloadcompcert-kvx-5a542d158d3bde832e38b65ad5347299fbe7ee32.tar.gz
compcert-kvx-5a542d158d3bde832e38b65ad5347299fbe7ee32.zip
Simplify tunneling factorisation
The recursive module definitions required unnecessarily long expicit signatures for little added legibility.
Diffstat (limited to 'backend')
-rw-r--r--backend/RTLTunnelingaux.ml65
-rw-r--r--backend/Tunnelinglibs.ml276
2 files changed, 173 insertions, 168 deletions
diff --git a/backend/RTLTunnelingaux.ml b/backend/RTLTunnelingaux.ml
index 5fe327d4..9333e357 100644
--- a/backend/RTLTunnelingaux.ml
+++ b/backend/RTLTunnelingaux.ml
@@ -26,8 +26,6 @@ open Maps
open Camlcoq
open Tunnelinglibs
-let nopcounter = ref 0
-
module LANG = struct
type code_unit = RTL.instruction
type funct = RTL.coq_function
@@ -40,60 +38,44 @@ module OPT = struct
let final_dump = false
end
-module rec T: sig
- val get_node: cfg -> positive -> node
- val set_branch: cfg -> positive -> node -> unit
- val debug: ('a, out_channel, unit) format -> 'a
- val string_of_labeli: ('a, node) Hashtbl.t -> 'a -> string
- exception BugOnPC of int
- val branch_target: RTL.coq_function -> (positive * Z.t) PTree.t
-
-end = Tunnelinglibs.Tunneling(LANG)(OPT)(FUNS)
-
-and FUNS: sig
- val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list
-
- val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool
+module Partial = Tunnelinglibs.Tunneling(LANG)(OPT)
- val fn_code: LANG.funct -> LANG.code_unit PTree.t
- val fn_entrypoint: LANG.funct -> positive
-
- val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit
-end = struct
+module FUNS = struct
let build_simplified_cfg c acc pc i =
match i with
| Inop s ->
- let ns = T.get_node c s in
- T.set_branch c pc ns;
+ let ns = get_node c s in
+ set_branch c pc ns;
incr nopcounter;
acc
| Icond (_, _, s1, s2, _) ->
c.num_rems <- c.num_rems + 1;
- let ns1 = T.get_node c s1 in
- let ns2 = T.get_node c s2 in
- let npc = T.get_node c pc in
+ let ns1 = get_node c s1 in
+ let ns2 = get_node c s2 in
+ let npc = get_node c pc in
npc.inst <- COND(ns1, ns2);
npc::acc
| _ -> acc
let print_code_unit c println (pc, i) =
match i with
- | Inop s -> (if println then T.debug "\n"); T.debug "%d:Inop %d %s\n" pc (P.to_int s) (T.string_of_labeli c.nodes pc); false
- | Icond (_, _, s1, s2, _) -> (if println then T.debug "\n"); T.debug "%d:Icond (%d,%d) %s\n" pc (P.to_int s1) (P.to_int s2) (T.string_of_labeli c.nodes pc); false
- | _ -> T.debug "%d " pc; true
+ | Inop s -> (if println then Partial.debug "\n");
+ Partial.debug "%d:Inop %d %s\n" pc (P.to_int s) (string_of_labeli c.nodes pc);
+ false
+ | Icond (_, _, s1, s2, _) -> (if println then Partial.debug "\n");
+ Partial.debug "%d:Icond (%d,%d) %s\n" pc (P.to_int s1) (P.to_int s2) (string_of_labeli c.nodes pc);
+ false
+ | _ -> Partial.debug "%d " pc;
+ true
let fn_code f = f.fn_code
let fn_entrypoint f = f.fn_entrypoint
+
(*************************************************************)
(* Copy-paste of the extracted code of the verifier *)
(* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)
- let get td pc =
- match PTree.get pc td with
- | Some p -> let (t0, d) = p in (t0, d)
- | None -> (pc, Z.of_uint 0)
-
let check_code_unit td pc i =
match PTree.get pc td with
| Some p ->
@@ -105,8 +87,8 @@ end = struct
if peq tpc ts
then if zlt ds dpc0
then ()
- else raise (T.BugOnPC (P.to_int pc))
- else raise (T.BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
| Icond (_, _, s1, s2, _) ->
let (ts1, ds1) = get td s1 in
let (ts2, ds2) = get td s2 in
@@ -115,15 +97,16 @@ end = struct
then if zlt ds1 dpc0
then if zlt ds2 dpc0
then ()
- else raise (T.BugOnPC (P.to_int pc))
- else raise (T.BugOnPC (P.to_int pc))
- else raise (T.BugOnPC (P.to_int pc))
- else raise (T.BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
+ else raise (BugOnPC (P.to_int pc))
| _ ->
- raise (T.BugOnPC (P.to_int pc)) end
+ raise (BugOnPC (P.to_int pc)) end
| None -> ()
end
+module T = Partial.T(FUNS)
let branch_target = T.branch_target
diff --git a/backend/Tunnelinglibs.ml b/backend/Tunnelinglibs.ml
index e1e61d68..1bb35f7a 100644
--- a/backend/Tunnelinglibs.ml
+++ b/backend/Tunnelinglibs.ml
@@ -39,51 +39,24 @@ and node = {
mutable tag: int
}
-type cfg_node = (int, node) Hashtbl.t
-
type positive = P.t
type integer = Z.t
(* type of the (simplified) CFG *)
type cfg = {
- nodes: cfg_node;
+ nodes: (int, node) Hashtbl.t;
mutable rems: node list; (* remaining conditions that may become lbranch or not *)
mutable num_rems: int;
mutable iter_num: int (* number of iterations in elimination of conditions *)
}
-module Tunneling = functor
- (LANG: sig
- type code_unit (* the type of a node of the code cfg (an instruction or a bblock *)
- type funct
- end)
- (OPT: sig
- val langname: string
- val limit_tunneling: int option (* for debugging: [Some x] limit the number of iterations *)
- val debug_flag: bool ref
- val final_dump: bool (* set to true to have a more verbose debugging *)
- end)
- (FUNS: sig
- (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks *)
- val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list
-
- val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool
-
- val fn_code: LANG.funct -> LANG.code_unit PTree.t
- val fn_entrypoint: LANG.funct -> positive
-
- val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit
- end)
--> struct
-
-let debug fmt =
- if !OPT.debug_flag then Printf.eprintf fmt
- else Printf.ifprintf stderr fmt
-
exception BugOnPC of int
-let lab_i (n: node): int = fst n.lab
-let lab_p (n: node): P.t = snd n.lab
+(* keeps track of the total number of nops seen, for debugging purposes *)
+let nopcounter = ref 0
+
+(* General functions that do not require language-specific context, and that
+ are used for building language-specific functions *)
let rec target c n = (* inspired from the "find" of union-find algorithm *)
match n.inst with
@@ -126,41 +99,14 @@ let set_branch c p s =
let n = { lab = (li,p); inst = BRANCH s; link = target c s; dist = 0; tag = 0 } in
Hashtbl.add c.nodes li n
+let get td pc =
+ match PTree.get pc td with
+ | Some p -> let (t0, d) = p in (t0, d)
+ | None -> (pc, Z.of_uint 0)
+let lab_i (n: node): int = fst n.lab
+let lab_p (n: node): P.t = snd n.lab
-(* try to change a condition into a branch
-[acc] is the current accumulator of conditions to consider in the next iteration of repeat_change_cond
-*)
-let try_change_cond c acc pc =
- match pc.inst with
- | COND(s1,s2) ->
- let ts1 = target c s1 in
- let ts2 = target c s2 in
- if ts1 == ts2 then (
- pc.link <- ts1;
- c.num_rems <- c.num_rems - 1;
- acc
- ) else
- pc::acc
- | _ -> raise (BugOnPC (lab_i pc)) (* COND expected *)
-
-(* repeat [try_change_cond] until no condition is changed into a branch *)
-let rec repeat_change_cond c =
- c.iter_num <- c.iter_num + 1;
- debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems);
- let old = c.num_rems in
- c.rems <- List.fold_left (try_change_cond c) [] c.rems;
- let curr = c.num_rems in
- let continue =
- match OPT.limit_tunneling with
- | Some n -> curr < old && c.iter_num < n
- | None -> curr < old
- in
- if continue
- then repeat_change_cond c
-
-
-(* compute the final distance of each nop nodes to its target *)
let undef_dist = -1
let self_dist = undef_dist-1
let rec dist n =
@@ -176,29 +122,6 @@ let rec dist n =
) else if n.dist=self_dist then raise (BugOnPC (lab_i n))
else n.dist
-let final_export f c =
- let count = ref 0 in
- let filter_nops_init_dist _ n acc =
- let tn = target c n in
- if tn == n
- then (
- n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *)
- acc
- ) else (
- n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *)
- count := !count+1;
- n::acc
- )
- in
- let nops = Hashtbl.fold filter_nops_init_dist c.nodes [] in
- let res = List.fold_left (fun acc n -> PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) acc) PTree.empty nops in
- debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n"
- OPT.langname !count;
- res
-
-(*********************************************)
-(*** START: printing and debugging functions *)
-
let string_of_labeli nodes ipc =
try
let pc = Hashtbl.find nodes ipc in
@@ -208,43 +131,142 @@ let string_of_labeli nodes ipc =
with
Not_found -> ""
-let print_cfg (f: LANG.funct) c =
- let a = Array.of_list (PTree.fold (fun acc pc cu -> (P.to_int pc,cu)::acc) (FUNS.fn_code f) []) in
- Array.fast_sort (fun (i1,_) (i2,_) -> i2 - i1) a;
- let ep = P.to_int (FUNS.fn_entrypoint f) in
- debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep);
- let println = Array.fold_left (FUNS.print_code_unit c) false a in
- (if println then debug "\n");debug "remaining cond:";
- List.iter (fun n -> debug "%d " (lab_i n)) c.rems;
- debug "\n"
-
-(*************************************************************)
-(* Copy-paste of the extracted code of the verifier *)
-(* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)
-
-(** val check_code : coq_UF -> code -> unit res **)
-
-let check_code td c =
- PTree.fold (fun _ pc cu -> FUNS.check_code_unit td pc cu) c (())
-
-(*** END: copy-paste & debugging functions *******)
-
-let branch_target f =
- debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname;
- if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n";
- let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in
- c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) [];
- repeat_change_cond c;
- let res = final_export f c in
- if !OPT.debug_flag then (
- try
- check_code res (FUNS.fn_code f);
- if OPT.final_dump then print_cfg f c;
- with e -> (
- print_cfg f c;
- check_code res (FUNS.fn_code f)
- )
- );
- res
+(*
+ * When given the necessary types and options as context, and then some
+ * language-specific functions that cannot be factorised between LTL and RTL, the
+ * `Tunneling` functor returns a module containing the corresponding
+ * `branch_target` function.
+ *)
+
+module Tunneling = functor
+ (* Language-specific types *)
+ (LANG: sig
+ type code_unit (* the type of a node of the code cfg (an instruction or a bblock *)
+ type funct (* type of internal functions *)
+ end)
+ (* Compilation options for debugging *)
+ (OPT: sig
+ val langname: string
+ val limit_tunneling: int option (* for debugging: [Some x] limit the number of iterations *)
+ val debug_flag: bool ref
+ val final_dump: bool (* set to true to have a more verbose debugging *)
+ end)
+ -> struct
+
+ (* The `debug` function uses values from `OPT`, and is used in functions passed to `F`
+ so it must be defined between the two *)
+ let debug fmt =
+ if !OPT.debug_flag then Printf.eprintf fmt
+ else Printf.ifprintf stderr fmt
+
+ module T
+ (* Language-specific functions *)
+ (FUNS: sig
+ (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks *)
+ val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list
+ val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool
+ val fn_code: LANG.funct -> LANG.code_unit PTree.t
+ val fn_entrypoint: LANG.funct -> positive
+ val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit
+ end)
+ (* only export what's needed *)
+ : sig val branch_target: LANG.funct -> (positive * integer) PTree.t end
+ = struct
+
+ (* try to change a condition into a branch [acc] is the current accumulator of
+ conditions to consider in the next iteration of repeat_change_cond *)
+ let try_change_cond c acc pc =
+ match pc.inst with
+ | COND(s1,s2) ->
+ let ts1 = target c s1 in
+ let ts2 = target c s2 in
+ if ts1 == ts2 then (
+ pc.link <- ts1;
+ c.num_rems <- c.num_rems - 1;
+ acc
+ ) else
+ pc::acc
+ | _ -> raise (BugOnPC (lab_i pc)) (* COND expected *)
+
+ (* repeat [try_change_cond] until no condition is changed into a branch *)
+ let rec repeat_change_cond c =
+ c.iter_num <- c.iter_num + 1;
+ debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems);
+ let old = c.num_rems in
+ c.rems <- List.fold_left (try_change_cond c) [] c.rems;
+ let curr = c.num_rems in
+ let continue =
+ match OPT.limit_tunneling with
+ | Some n -> curr < old && c.iter_num < n
+ | None -> curr < old
+ in
+ if continue
+ then repeat_change_cond c
+
+
+ (*********************************************)
+ (*** START: printing and debugging functions *)
+
+ let print_cfg (f: LANG.funct) c =
+ let a = Array.of_list (PTree.fold (fun acc pc cu -> (P.to_int pc,cu)::acc) (FUNS.fn_code f) []) in
+ Array.fast_sort (fun (i1,_) (i2,_) -> i2 - i1) a;
+ let ep = P.to_int (FUNS.fn_entrypoint f) in
+ debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep);
+ let println = Array.fold_left (FUNS.print_code_unit c) false a in
+ (if println then debug "\n");debug "remaining cond:";
+ List.iter (fun n -> debug "%d " (lab_i n)) c.rems;
+ debug "\n"
+
+
+ (*************************************************************)
+ (* Copy-paste of the extracted code of the verifier *)
+ (* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)
+
+ (** val check_code : coq_UF -> code -> unit res **)
+
+ let check_code td c =
+ PTree.fold (fun _ pc cu -> FUNS.check_code_unit td pc cu) c (())
+
+ (*** END: copy-paste & debugging functions *******)
+
+ (* compute the final distance of each nop nodes to its target *)
+ let final_export f c =
+ let count = ref 0 in
+ let filter_nops_init_dist _ n acc =
+ let tn = target c n in
+ if tn == n
+ then (
+ n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *)
+ acc
+ ) else (
+ n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *)
+ count := !count+1;
+ n::acc
+ )
+ in
+ let nops = Hashtbl.fold filter_nops_init_dist c.nodes [] in
+ let res = List.fold_left (fun acc n -> PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) acc) PTree.empty nops in
+ debug "* %sTunneling.branch_target: initial number of nops = %d\n" OPT.langname !nopcounter;
+ debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n" OPT.langname !count;
+ res
+
+ let branch_target f =
+ debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname;
+ if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n";
+ let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in
+ c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) [];
+ repeat_change_cond c;
+ let res = final_export f c in
+ if !OPT.debug_flag then (
+ try
+ check_code res (FUNS.fn_code f);
+ if OPT.final_dump then print_cfg f c;
+ with e -> (
+ print_cfg f c;
+ check_code res (FUNS.fn_code f)
+ )
+ );
+ res
+ end
end