aboutsummaryrefslogtreecommitdiffstats
path: root/backend/Tunnelinglibs.ml
blob: 010595be53654fccfc73bd517627fb7009cba0f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
(* *************************************************************)
(*                                                             *)
(*             The Compcert verified compiler                  *)
(*                                                             *)
(*           Sylvain Boulmé  Grenoble-INP, VERIMAG             *)
(*           Pierre Goutagny ENS-Lyon, VERIMAG                 *)
(*                                                             *)
(*  Copyright VERIMAG. All rights reserved.                    *)
(*  This file is distributed under the terms of the INRIA      *)
(*  Non-Commercial License Agreement.                          *)
(*                                                             *)
(* *************************************************************)

(*

This file implements the core functions of the tunneling passes for both RTL
and LTL. It uses a simplified CFG as a transparent (i.e. non-abstract) interface
with the language-specific code.

See [LTLTunneling.v]/[LTLTunnelingaux.ml] and [RTLTunneling.v]/[RTLTunnelingaux.ml].

*)

open Maps
open Camlcoq

(* type of labels in the cfg: a label [p] is stored as [(P.to_int p, p)]
   (see [get_node] and [set_branch]) *)
type label = int * P.t

(* instructions under analysis *)
type simple_inst = (* a simplified view of instructions *)
  BRANCH of node (* unconditional jump to the given node *)
| COND of node * node (* condition with two successor nodes *)
| OTHER (* any other instruction; also the initial kind of a node created by [get_node] *)
and node = {
    lab: label;
    mutable inst: simple_inst;
    mutable link: node; (* link in the union-find: itself for non "nop"-nodes, target of the "nop" otherwise;
                           it may point to an intermediate nop until compressed by [target] *)
    mutable dist: int; (* distance to the target, computed by [dist]; [undef_dist] and [self_dist] are sentinels *)
    mutable tag: int (* last [cfg.iter_num] in which [target] tried to change this COND into a nop *)
  }

(* aliases used in the signatures of the [Tunneling] functor *)
type positive = P.t
type integer = Z.t

(* type of the (simplified) CFG *)
type cfg = {
    nodes: (int, node) Hashtbl.t; (* nodes indexed by the integer view of their label *)
    mutable rems: node list; (* remaining conditions that may still become branches, or not *)
    mutable num_rems: int; (* decremented by [try_change_cond] each time a condition becomes a branch *)
    mutable iter_num: int (* number of iterations in elimination of conditions *)
  }

(* Raised when the data-structure is in an unexpected state (e.g. a cycle found
   by [dist], or a non-COND node among the remaining conditions); the argument
   is the integer label of the offending node (see [lab_i]). *)
exception BugOnPC of int

(* keeps track of the total number of nops seen, for debugging purposes.
   In this file it is only read, by a debug message of [final_export];
   NOTE(review): presumably incremented by the language-specific
   [build_simplified_cfg] — confirm. *)
let nopcounter = ref 0

(* General functions that do not require language-specific context, and that
 are used for building language-specific functions *)

(* [target c n] returns the node that control reaches from [n] once "nop"
   nodes are skipped. Every non-self [link] followed on the way is overwritten
   with the result (path compression, as in union-find).
   A COND node whose [link] is still itself is examined at most once per
   iteration of condition elimination: [n.tag] records the last [c.iter_num]
   in which it was tried. This also stops the recursion on cycles through
   conditions. If both successors have the same target, the COND becomes a nop
   linked to that target. [c.num_rems] is not updated here; [try_change_cond]
   does that. Since [tag] and [iter_num] both start at 0, no COND is examined
   before the first iteration. *)
let rec target c n = (* inspired from the "find" of union-find algorithm *)
  match n.inst with
  | COND(s1,s2) ->
     if n.link != n
     then update c n (* already known to be a nop: follow its link *)
     else if n.tag < c.iter_num then (
       (* we try to change the condition ... *)
       n.tag <- c.iter_num; (* ... but at most once by iteration *)
       let ts1 = target c s1 in
       let ts2 = target c s2 in
       if ts1 == ts2 then (n.link <- ts1; ts1) else n
     ) else n
  | _ ->
     if n.link != n
     then update c n
     else n
(* [update c n] follows [n.link] to its final target and stores that target
   back into [n.link] (the path-compression step of [target]). *)
and update c n =
  let t = target c n.link in
  n.link <- t; t

(* Return the node labelled [p] in [c.nodes]. If [p] has not been seen yet,
   create a fresh OTHER node linked to itself, register it, and return it. *)
let get_node c p =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | existing -> existing
  | exception Not_found ->
      let rec fresh = { lab = (key, p); inst = OTHER; link = fresh; dist = 0; tag = 0 } in
      Hashtbl.add c.nodes key fresh;
      fresh

(* Record that the node labelled [p] is an unconditional branch to [s], and
   link it to the current target of [s]. The node is created and registered
   in [c.nodes] if it does not exist yet. *)
let set_branch c p s =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | existing ->
      existing.inst <- BRANCH s;
      existing.link <- target c s
  | exception Not_found ->
      let linked = target c s in
      Hashtbl.add c.nodes key
        { lab = (key, p); inst = BRANCH s; link = linked; dist = 0; tag = 0 }

(* Look up [pc] in the tunneling result [td], which maps a label to a pair
   (target label, distance). A label absent from [td] is returned as its own
   target, at distance 0. *)
let get td pc =
  match PTree.get pc td with
  | Some target_and_dist -> target_and_dist
  | None -> (pc, Z.of_uint 0)

(* Integer view and positive view of the label of a node. *)
let lab_i ({ lab = (i, _); _ } : node) : int = i
let lab_p ({ lab = (_, p); _ } : node) : P.t = p

(* Sentinel values of [node.dist]: [undef_dist] marks a distance still to be
   computed, and [self_dist] a distance currently being computed by [dist]. *)
let undef_dist = -1
let self_dist = undef_dist-1
(* [dist n] returns [n.dist]. When it is [undef_dist], the distance is first
   computed and memoized: 0 for an OTHER node, 1 + the successor's distance for
   a BRANCH, and 1 + the maximum of the two successors' distances for a COND.
   Raises [BugOnPC] with the node's integer label when called on a node whose
   distance is already being computed, i.e. when a cycle is found. *)
let rec dist n =
  if n.dist = undef_dist
  then (
    n.dist <- self_dist; (* protection against an unexpected loop in the data-structure *)
    n.dist <-
      (match n.inst with
       | OTHER -> 0
       | BRANCH p -> 1 + dist p
       | COND (p1,p2) -> 1 + (max (dist p1) (dist p2)));
    n.dist
  ) else if n.dist=self_dist then raise (BugOnPC (lab_i n))
    else n.dist

(* Debug rendering of the node with integer label [ipc] in [nodes]:
   "(Target@d)" if the node is its own link, otherwise "(Nop l @d)" where [l]
   is the label of its link. [d] is the node's distance. Returns "" when no
   node has label [ipc]. *)
let string_of_labeli nodes ipc =
  match Hashtbl.find nodes ipc with
  | exception Not_found -> ""
  | n ->
      if n.link == n
      then Printf.sprintf "(Target@%d)" (dist n)
      else Printf.sprintf "(Nop %d @%d)" (lab_i n.link) (dist n)

(*
 * Given the necessary types and options as context, the `Tunneling` functor
 * returns a module containing the functor `T`. Applying `T` to the
 * language-specific functions that cannot be factorised between LTL and RTL
 * yields a module exporting the corresponding `branch_target` function.
 *)

module Tunneling = functor
  (* Language-specific types *)
  (LANG: sig
    type code_unit (* the type of a node of the code cfg (an instruction or a bblock) *)
    type funct (* type of internal functions *)
  end)

  (* Compilation options for debugging *)
  (OPT: sig
    val langname: string
    val limit_tunneling: int option (* for debugging: [Some x] limit the number of iterations *)
    val debug_flag: bool ref
    val final_dump: bool (* set to true to have a more verbose debugging *)
  end)
  -> struct

  (* [debug fmt args...] prints on stderr when [!OPT.debug_flag] is set, and
     discards its arguments otherwise. It uses values from [OPT], and is meant
     to be used in the functions passed to [T], so it must be defined between
     the two. *)
  let debug fmt =
    if !OPT.debug_flag then Printf.eprintf fmt
    else Printf.ifprintf stderr fmt

  module T
    (* Language-specific functions *)
    (FUNS: sig
      (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks *)
      val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list
      val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool
      val fn_code: LANG.funct -> LANG.code_unit PTree.t
      val fn_entrypoint: LANG.funct -> positive
      val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit
    end)
    (* only export what's needed *)
    : sig val branch_target: LANG.funct -> (positive * integer) PTree.t end
  = struct

    (* Try to change the condition [pc] into a branch (i.e. a nop) when both
       of its successors have the same target. [acc] is the current accumulator
       of conditions to consider in the next iteration of [repeat_change_cond];
       [pc] is added to it only when it is left unchanged.
       Raises [BugOnPC] if [pc] is not a COND. *)
    let try_change_cond c acc pc =
      match pc.inst with
      | COND(s1,s2) ->
         let ts1 = target c s1 in
         let ts2 = target c s2 in
         if ts1 == ts2 then (
           pc.link <- ts1;
           c.num_rems <- c.num_rems - 1;
           acc
         ) else
           pc::acc
      | BRANCH _ | OTHER -> raise (BugOnPC (lab_i pc)) (* COND expected *)

    (* Repeat [try_change_cond] on the remaining conditions until an iteration
       changes none of them. When [OPT.limit_tunneling] is [Some n], also stop
       once [n] iterations have been performed. *)
    let rec repeat_change_cond c =
      c.iter_num <- c.iter_num + 1;
      debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems);
      let old = c.num_rems in
      c.rems <- List.fold_left (try_change_cond c) [] c.rems;
      let progress = c.num_rems < old in
      let under_limit =
        match OPT.limit_tunneling with
        | Some n -> c.iter_num < n
        | None -> true
      in
      if progress && under_limit
      then repeat_change_cond c


    (*********************************************)
    (*** START: printing and debugging functions *)

    (* Dump the code of [f] (by decreasing label) together with the state of
       the simplified CFG [c], through [debug]. *)
    let print_cfg (f: LANG.funct) c =
      let a = Array.of_list (PTree.fold (fun acc pc cu -> (P.to_int pc,cu)::acc) (FUNS.fn_code f) []) in
      Array.fast_sort (fun (i1,_) (i2,_) -> i2 - i1) a;
      let ep = P.to_int (FUNS.fn_entrypoint f) in
      debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep);
      let println = Array.fold_left (FUNS.print_code_unit c) false a in
      if println then debug "\n";
      debug "remaining cond:";
      List.iter (fun n -> debug "%d " (lab_i n)) c.rems;
      debug "\n"


    (*************************************************************)
    (* Copy-paste of the extracted code of the verifier          *)
    (* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)

    (** val check_code : coq_UF -> code -> unit res **)

    let check_code td c =
      PTree.fold (fun _ pc cu -> FUNS.check_code_unit td pc cu) c (())

    (*** END: copy-paste & debugging functions *******)

    (* Compute the final distance of each nop node to its target, and return
       the tunneling result. The result maps the label of each nop node to the
       pair (label of its target, distance); nodes that are their own target
       are absent from it. *)
    let final_export c =
      let count = ref 0 in
      let filter_nops_init_dist _ n acc =
        let tn = target c n in
        if tn == n
        then (
          n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *)
          acc
        ) else (
          n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *)
          count := !count+1;
          n::acc
        )
      in
      let nops = Hashtbl.fold filter_nops_init_dist c.nodes [] in
      let res = List.fold_left (fun acc n -> PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) acc) PTree.empty nops in
      debug "* %sTunneling.branch_target: initial number of nops = %d\n" OPT.langname !nopcounter;
      debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n" OPT.langname !count;
      res

    (* Compute the tunneling result of [f]. It maps each skippable label to its
       final target and its distance to it. In debug mode, the result is also
       checked with [check_code]; on failure the CFG is dumped and the original
       exception (typically [BugOnPC]) is propagated. *)
    let branch_target f =
      debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname;
      if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n";
      let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in
      c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) [];
      repeat_change_cond c;
      let res = final_export c in
      if !OPT.debug_flag then (
        try
          check_code res (FUNS.fn_code f);
          if OPT.final_dump then print_cfg f c
        with e ->
          (* Dump the CFG to help diagnose the failure, then propagate the
             original exception. Re-running the check to re-raise would cost a
             second traversal, and would silently swallow a failure of the
             final dump when the check itself succeeded. *)
          print_cfg f c;
          raise e
      );
      res
  end
end