aboutsummaryrefslogtreecommitdiffstats
path: root/backend/Tunnelinglibs.ml
blob: e1e61d68153b432043b9bdf0f82a67bbb33cf88a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
(* *************************************************************)
(*                                                             *)
(*             The Compcert verified compiler                  *)
(*                                                             *)
(*           Sylvain Boulmé  Grenoble-INP, VERIMAG             *)
(*           TODO: Proper author information                   *)
(*                                                             *)
(*  Copyright VERIMAG. All rights reserved.                    *)
(*  This file is distributed under the terms of the INRIA      *)
(*  Non-Commercial License Agreement.                          *)
(*                                                             *)
(* *************************************************************)

(*

This file implements the core functions of the tunneling passes for both RTL
and LTL. The language-specific code is kept behind a simplified CFG, which
serves as the common interface between the two passes.

See [Tunneling.v] and [RTLTunneling.v]

*)

open Maps
open Camlcoq

(* Type of labels in the CFG: the same program point stored twice, as a native
   [int] and as a [P.t]; [get_node] and [set_branch] build it as
   [(P.to_int p, p)]. *)
type label = int * P.t

(* Instructions under analysis: a simplified view of instructions that keeps
   only their successors. [simple_inst] and [node] are mutually recursive
   because an instruction refers to its successor nodes. *)
type simple_inst =
  BRANCH of node (* one successor node *)
| COND of node * node (* two successor nodes; merged by [target] when both resolve to the same node *)
| OTHER (* anything else; counts for 0 in [dist] *)
and node = {
    lab: label; (* label of this node *)
    mutable inst: simple_inst; (* simplified instruction found at this label *)
    mutable link: node; (* link in the union-find: itself for non "nop"-nodes, target of the "nop" otherwise *)
    mutable dist: int; (* cache used by [dist]; (re)initialised by [final_export] before [dist] is called *)
    mutable tag: int (* last value of [iter_num] at which [target] tried to merge this condition *)
  }

(* Table of the CFG nodes, keyed by the [int] component of their label
   (see [get_node] and [set_branch]). *)
type cfg_node = (int, node) Hashtbl.t

(* Short names for [P.t] and [Z.t], used in the signatures below. *)
type positive = P.t
type integer = Z.t

(* Type of the (simplified) CFG, together with the mutable state of the
   condition-elimination loop run by [repeat_change_cond]. *)
type cfg = {
    nodes: cfg_node; (* all nodes created so far, by integer label *)
    mutable rems: node list; (* remaining conditions that may still be turned into a branch *)
    mutable num_rems: int; (* counter decremented by [try_change_cond] each time a condition is merged;
                              NOTE(review): nothing in this file increments it, presumably
                              [build_simplified_cfg] does -- confirm against its implementations *)
    mutable iter_num: int (* number of iterations in elimination of conditions *)
  }

(* Generic tunneling oracle, parameterised by the language it works on:
   - [LANG] provides the types of code units and of functions;
   - [OPT] provides a name for debug messages and the debugging options;
   - [FUNS] provides the language-specific operations on these types. *)
module Tunneling = functor
  (LANG: sig
    type code_unit (* the type of a node of the code cfg (an instruction or a bblock) *)
    type funct
  end)
  (OPT: sig
    val langname: string (* prefix used in debug messages *)
    val limit_tunneling: int option (* for debugging: [Some x] limits the number of iterations *)
    val debug_flag: bool ref (* when set, [debug] prints on stderr and [branch_target] checks its result *)
    val final_dump: bool (* set to true to dump the CFG even when the check succeeds *)
  end)
  (FUNS: sig
    (* build [c.nodes] and accumulate in [acc] conditions at beginning of LTL basic-blocks;
       used as the folding function over the code tree in [branch_target],
       whose final accumulator becomes [c.rems] *)
    val build_simplified_cfg: cfg -> node list -> positive -> LANG.code_unit -> node list

    (* print one code unit; folded by [print_cfg] over the code units with a
       [bool] accumulator (initially [false]) whose final value tells whether
       a newline must still be printed *)
    val print_code_unit: cfg -> bool -> int * LANG.code_unit -> bool

    (* code tree and entry point of a function *)
    val fn_code: LANG.funct -> LANG.code_unit PTree.t
    val fn_entrypoint: LANG.funct -> positive

    (* check one code unit against the computed table; called by [check_code]
       on every code unit (presumably raises [BugOnPC] on failure -- confirm
       against the implementations) *)
    val check_code_unit: (positive * integer) PTree.t -> positive -> LANG.code_unit -> unit
  end)
-> struct

(* [debug fmt arg1 ... argN] formats on stderr like [Printf.eprintf] when
   [OPT.debug_flag] is set. Otherwise [Printf.ifprintf] consumes the arguments
   and prints nothing. *)
let debug fmt =
  let printer =
    if !OPT.debug_flag then Printf.eprintf else Printf.ifprintf stderr
  in
  printer fmt

(* Raised with the integer label of the offending node when an internal
   invariant is broken: a non-[COND] node given to [try_change_cond], or a
   cycle detected by [dist]. *)
exception BugOnPC of int

(* Integer component of the label of a node. *)
let lab_i ({ lab = (i, _); _ }: node): int = i
(* [P.t] component of the label of a node. *)
let lab_p ({ lab = (_, p); _ }: node): P.t = p

(* [target c n] returns the representative of [n], in the manner of the "find"
   of a union-find: links are followed and shortened along the way by [update].

   A node that is its own link is its own representative, except that a [COND]
   node whose two successors resolve to the same node gets linked to that
   node. This attempt is made at most once per iteration for a given node:
   [n.tag] records the value of [c.iter_num] at the last attempt. *)
let rec target c n =
  if n.link != n then update c n
  else
    match n.inst with
    | BRANCH _ | OTHER -> n
    | COND (left, right) ->
       if n.tag >= c.iter_num then n (* already tried during this iteration *)
       else begin
         n.tag <- c.iter_num;
         let tleft = target c left in
         let tright = target c right in
         if tleft != tright then n
         else begin
           (* both successors end up at the same node: [n] acts as a branch *)
           n.link <- tleft;
           tleft
         end
       end
(* [update c n] resolves the link of [n], stores the result back in [n.link]
   and returns it. *)
and update c n =
  let root = target c n.link in
  n.link <- root;
  root

(* [get_node c p] returns the node registered in [c.nodes] for label [p].
   If there is none, a fresh [OTHER] node linked to itself is registered and
   returned. *)
let get_node c p =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | known -> known
  | exception Not_found ->
     let rec fresh =
       { lab = (key, p); inst = OTHER; link = fresh; dist = 0; tag = 0 }
     in
     Hashtbl.add c.nodes key fresh;
     fresh

(* [set_branch c p s] records that label [p] holds a branch to node [s]:
   the node of [p] (created if it is not yet in [c.nodes]) gets instruction
   [BRANCH s] and is linked to the current target of [s]. *)
let set_branch c p s =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | existing ->
     existing.inst <- BRANCH s;
     existing.link <- target c s
  | exception Not_found ->
     let root = target c s in
     let fresh = { lab = (key, p); inst = BRANCH s; link = root; dist = 0; tag = 0 } in
     Hashtbl.add c.nodes key fresh



(* [try_change_cond c acc pc] tries to turn the condition [pc] into a branch.
   When both successors of [pc] have the same target, [pc] is linked to it,
   [c.num_rems] is decremented and [acc] is returned unchanged. Otherwise [pc]
   is pushed on [acc], the accumulator of conditions to reconsider in the next
   iteration of [repeat_change_cond].
   Raises [BugOnPC] if [pc] is not a [COND] node. *)
let try_change_cond c acc pc =
  match pc.inst with
  | BRANCH _ | OTHER -> raise (BugOnPC (lab_i pc)) (* COND expected *)
  | COND (s1, s2) ->
     let ts1 = target c s1 in
     let ts2 = target c s2 in
     if ts1 != ts2 then pc :: acc
     else begin
       pc.link <- ts1;
       c.num_rems <- c.num_rems - 1;
       acc
     end

(* Run [try_change_cond] over all remaining conditions, again and again, until
   a full pass leaves [c.num_rems] unchanged. With
   [OPT.limit_tunneling = Some bound], the loop also stops once [c.iter_num]
   reaches [bound]. *)
let rec repeat_change_cond c =
  c.iter_num <- c.iter_num + 1;
  debug "++ %sTunneling.branch_target %d: remaining number of conds to consider = %d\n" OPT.langname (c.iter_num) (c.num_rems);
  let before = c.num_rems in
  c.rems <- List.fold_left (try_change_cond c) [] c.rems;
  let progressed = c.num_rems < before in
  let under_limit =
    match OPT.limit_tunneling with
    | Some bound -> c.iter_num < bound
    | None -> true
  in
  if progressed && under_limit then repeat_change_cond c


(* Compute the final distance of each nop node to its target.

   [n.dist] is used as a cache with two sentinel values:
   - [undef_dist]: not computed yet;
   - [self_dist]: computation in progress. Meeting it again means that the
     data-structure contains an unexpected loop, reported with [BugOnPC]. *)
let undef_dist = -1
let self_dist = undef_dist - 1
let rec dist n =
  let cached = n.dist in
  if cached = self_dist then raise (BugOnPC (lab_i n))
  else if cached <> undef_dist then cached
  else begin
    n.dist <- self_dist; (* mark [n] as being visited *)
    let d =
      match n.inst with
      | OTHER -> 0
      | BRANCH p -> 1 + dist p
      | COND (p1, p2) -> 1 + (max (dist p1) (dist p2))
    in
    n.dist <- d;
    d
  end

(* [final_export f c] builds the table returned by the oracle: every node of
   [c.nodes] that is not its own target is bound, under its [P.t] label, to
   the pair of the label of its target and of its distance (see [dist])
   converted with [Z.of_uint].

   Before any distance is computed, nodes that are their own target get
   distance 0, so that they are base cases of [dist], and all other nodes are
   reset to [undef_dist]. The argument [f] is not used. *)
let final_export f c =
  let collect_nops _ n nops =
    if target c n == n then begin
      n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *)
      nops
    end else begin
      n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *)
      n :: nops
    end
  in
  let nops = Hashtbl.fold collect_nops c.nodes [] in
  let bind table n =
    PTree.set (lab_p n) (lab_p n.link, Z.of_uint (dist n)) table
  in
  let res = List.fold_left bind PTree.empty nops in
  debug "* %sTunneling.branch_target: final number of eliminated nops = %d\n"
    OPT.langname (List.length nops);
  res

(*********************************************)
(*** START: printing and debugging functions *)

(* [string_of_labeli nodes ipc] describes the node registered under the
   integer label [ipc]: a node linked to itself is shown as a target, any
   other node as a nop with the label of its link, each time with its
   distance. Returns the empty string when [ipc] is not in [nodes]. *)
let string_of_labeli nodes ipc =
  match Hashtbl.find nodes ipc with
  | exception Not_found -> ""
  | pc when pc.link == pc -> Printf.sprintf "(Target@%d)" (dist pc)
  | pc -> Printf.sprintf "(Nop %d @%d)" (lab_i pc.link) (dist pc)

(* [print_cfg f c] dumps, through [debug], the entry point of [f], then each
   code unit of [f] by decreasing integer label (printed by
   [FUNS.print_code_unit]), and finally the labels of the conditions still
   listed in [c.rems]. *)
let print_cfg (f: LANG.funct) c  =
  let units =
    PTree.fold (fun acc pc cu -> (P.to_int pc, cu) :: acc) (FUNS.fn_code f) []
  in
  let sorted = Array.of_list units in
  let by_decreasing_label (i1, _) (i2, _) = i2 - i1 in
  Array.fast_sort by_decreasing_label sorted;
  let ep = P.to_int (FUNS.fn_entrypoint f) in
  debug "entrypoint: %d %s\n" ep (string_of_labeli c.nodes ep);
  let println = Array.fold_left (FUNS.print_code_unit c) false sorted in
  if println then debug "\n";
  debug "remaining cond:";
  List.iter (fun n -> debug "%d " (lab_i n)) c.rems;
  debug "\n"

(*************************************************************)
(* Copy-paste of the extracted code of the verifier          *)
(* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)

(* [check_code td c] runs [FUNS.check_code_unit td] on every binding of the
   code tree [c] and returns unit. Any exception raised by a unit check
   propagates to the caller. *)
let check_code td c =
  let check_one () pc cu = FUNS.check_code_unit td pc cu in
  PTree.fold check_one c ()

(*** END: copy-paste & debugging functions *******)

(* [branch_target f] is the entry point of the oracle. It builds the
   simplified CFG of [f], eliminates conditions with [repeat_change_cond] and
   returns the table built by [final_export].

   When [OPT.debug_flag] is set, the table is checked with [check_code]:
   - if the check succeeds, the CFG is dumped only when [OPT.final_dump] is set;
   - if the check raises, the CFG is dumped and the exception is re-raised.

   The handler covers [check_code] only, and the original exception is
   re-raised as is: the check is not run a second time, and a failure of the
   final dump does not trigger another dump. *)
let branch_target f =
  debug "* %sTunneling.branch_target: starting on a new function\n" OPT.langname;
  if OPT.limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n";
  let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in
  c.rems <- PTree.fold (FUNS.build_simplified_cfg c) (FUNS.fn_code f) [];
  repeat_change_cond c;
  let res = final_export f c in
  if !OPT.debug_flag then begin
    match check_code res (FUNS.fn_code f) with
    | () -> if OPT.final_dump then print_cfg f c
    | exception e ->
       (* show the state that made the check fail, then report the failure *)
       print_cfg f c;
       raise e
  end;
  res

end