aboutsummaryrefslogtreecommitdiffstats
path: root/backend/RTLTunnelingaux.ml
blob: a30b43cfb31b45da12d8ab2d35e64f09ef3f6779 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
(* *************************************************************)
(*                                                             *)
(*             The Compcert verified compiler                  *)
(*                                                             *)
(*           Sylvain Boulmé  Grenoble-INP, VERIMAG             *)
(*           TODO: Proper author information                   *)
(*                                                             *)
(*  Copyright VERIMAG. All rights reserved.                    *)
(*  This file is distributed under the terms of the INRIA      *)
(*  Non-Commercial License Agreement.                          *)
(*                                                             *)
(* *************************************************************)

(*

This file implements the [branch_target] oracle that identifies "nop" branches in an RTL function,
and computes their target node together with the distance (i.e. the number of accumulated nops) toward this target.

See [RTLTunneling.v]

*)

open Coqlib
open RTL
open Maps
open Camlcoq

(* Debugging knobs of the oracle. *)
let limit_tunneling = None (* for debugging: [Some x] limit the number of iterations *)
(* When set, [debug] prints on stderr and [branch_target] re-checks its own result.
   NOTE(review): debugging is enabled by default here -- confirm this is intended. *)
let debug_flag = ref true
let final_dump = false   (* set to true to have a more verbose debugging *)

(* [debug fmt args...] behaves like [Printf.eprintf] when [debug_flag] is set;
   otherwise the arguments are consumed and nothing is printed. *)
let debug fmt =
  match !debug_flag with
  | true -> Printf.eprintf fmt
  | false -> Printf.ifprintf stderr fmt

(* Raised when an inconsistency is detected; the payload is the integer form
   of the label of the offending node. *)
exception BugOnPC of int

(* Counter of the [Inop] instructions met by [build_simplified_cfg];
   reported in the statistics printed by [final_export]. *)
let nopcounter = ref 0

(* type of labels in the cfg: the same program point, as an [int] (first component,
   used as key of the hashtable of nodes) and as a [P.t] (second component) *)
type label = int * P.t

(* instructions under analysis *)
type simple_inst = (* a simplified view of RTL instructions *)
  INOP of node (* a nop, with its successor *)
| ICOND of node * node (* a condition, with its two successors *)
| OTHER (* any other instruction (also the initial state of a node created by [get_node]) *)
and node = {
    lab : label; (* label of the node, see [label] *)
    mutable inst: simple_inst; (* simplified instruction at this node *)
    mutable link: node; (* link in the union-find: itself for non "nop"-nodes, target of the "nop" otherwise *)
    mutable dist: int; (* distance to the target, filled by [dist]; [undef_dist] and [self_dist] are used as markers *)
    mutable tag: int (* last iteration number at which [target] tried to change this condition *)
  }

(* type of the (simplified) CFG *)
type cfg = {
    nodes: (int, node) Hashtbl.t; (* nodes, indexed by the integer form of their label *)
    mutable rems: node list; (* remaining conditions that may become lbranch or not *)
    mutable num_rems: int; (* number of conditions not yet changed: incremented by [build_simplified_cfg], decremented by [try_change_cond] *)
    mutable iter_num: int (* number of iterations in elimination of conditions *)
  }

(* Projections of the label of a node: integer form, then positive form. *)
let lab_i (n: node): int = let (i, _) = n.lab in i
let lab_p (n: node): P.t = let (_, p) = n.lab in p

(* [target c n] returns the current representative of [n]
   (inspired from the "find" of union-find algorithm).
   - A node whose [link] is not itself (physical comparison) delegates to [update],
     which also shortens the chain of links.
   - An unlinked condition not yet examined during the current iteration is tagged with
     the iteration number (so it is examined at most once per iteration); if both of its
     successors have the same representative, the condition is linked to it.
   - Any other node is its own representative. *)
let rec target c n =
  if n.link != n then update c n
  else
    match n.inst with
    | ICOND (s1, s2) when n.tag < c.iter_num ->
       n.tag <- c.iter_num;
       let ts1 = target c s1 in
       let ts2 = target c s2 in
       if ts1 == ts2 then (n.link <- ts1; ts1) else n
    | ICOND _ | INOP _ | OTHER -> n

(* [update c n] computes the representative of [n.link], stores it back in [n.link]
   and returns it. *)
and update c n =
  let root = target c n.link in
  n.link <- root;
  root

(* [get_node c p] returns the node of [c] registered for label [p].
   When there is none yet, a fresh [OTHER] node linked to itself is registered and returned. *)
let get_node c p =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | found -> found
  | exception Not_found ->
     let rec fresh = { lab = (key, p); inst = OTHER; link = fresh; dist = 0; tag = 0 } in
     Hashtbl.add c.nodes key fresh;
     fresh

(* [set_branch c p s] records that the node at label [p] is a nop toward node [s]:
   its instruction becomes [INOP s] and its link the current representative of [s].
   The node is created and registered in [c.nodes] if it does not exist yet. *)
let set_branch c p s =
  let key = P.to_int p in
  match Hashtbl.find c.nodes key with
  | existing ->
     existing.inst <- INOP s;
     existing.link <- target c s
  | exception Not_found ->
     let root = target c s in
     Hashtbl.add c.nodes key
       { lab = (key, p); inst = INOP s; link = root; dist = 0; tag = 0 }


(* Folding step over the RTL code: registers instruction [i] found at [pc] into [c.nodes].
   - [Inop]: the node at [pc] becomes a nop branch toward its successor, and [nopcounter] is bumped.
   - [Icond]: the node at [pc] becomes a condition over its two successors, [c.num_rems] is bumped
     and the node is pushed on the accumulator [acc] of conditions.
   - any other instruction leaves [c] and [acc] unchanged.
   Returns the updated accumulator. *)
let build_simplified_cfg c acc pc i =
  match i with
  | Inop succ ->
     let nsucc = get_node c succ in
     set_branch c pc nsucc;
     nopcounter := !nopcounter + 1;
     acc
  | Icond (_, _, ifso, ifnot, _) ->
     c.num_rems <- succ c.num_rems;
     let n_ifso = get_node c ifso in
     let n_ifnot = get_node c ifnot in
     let cond = get_node c pc in
     cond.inst <- ICOND (n_ifso, n_ifnot);
     cond :: acc
  | _ -> acc

(* Try to change the condition [pc] into a branch.
   When both successors of [pc] have the same representative, [pc] is linked to it and
   [c.num_rems] is decremented; otherwise [pc] is kept in [acc], the accumulator of
   conditions to consider in the next iteration of [repeat_change_cond].
   Raises [BugOnPC] if [pc] is not a condition. *)
let try_change_cond c acc pc =
  match pc.inst with
  | ICOND (s1, s2) ->
     let ts1 = target c s1 in
     let ts2 = target c s2 in
     if ts1 != ts2 then pc :: acc
     else begin
       pc.link <- ts1;
       c.num_rems <- c.num_rems - 1;
       acc
     end
  | INOP _ | OTHER -> raise (BugOnPC (lab_i pc)) (* ICOND expected *)

(* Apply [try_change_cond] to all remaining conditions, again and again, as long as an
   iteration makes [c.num_rems] decrease (and, when [limit_tunneling] is [Some n],
   as long as fewer than [n] iterations have been run). *)
let repeat_change_cond c =
  let progress = ref true in
  while !progress do
    c.iter_num <- c.iter_num + 1;
    debug "++ RTLTunneling.branch_target %d: remaining number of conds to consider = %d\n" (c.iter_num) (c.num_rems);
    let before = c.num_rems in
    c.rems <- List.fold_left (try_change_cond c) [] c.rems;
    let decreased = c.num_rems < before in
    progress :=
      (match limit_tunneling with
       | Some n -> decreased && c.iter_num < n
       | None -> decreased)
  done


(* compute the final distance of each nop node to its target *)

(* marker: distance not computed yet *)
let undef_dist = -1
(* marker: distance under computation *)
let self_dist = undef_dist - 1

(* [dist n] returns the distance stored in [n], computing and storing it first when it is
   still [undef_dist]: 0 for [OTHER], one more than the successor for [INOP], one more than
   the farthest successor for [ICOND].
   Raises [BugOnPC] when the computation reaches a node that is itself under computation. *)
let rec dist n =
  if n.dist = self_dist then raise (BugOnPC (lab_i n))
  else if n.dist <> undef_dist then n.dist
  else begin
    n.dist <- self_dist; (* protection against an unexpected loop in the data-structure *)
    let d =
      match n.inst with
      | OTHER -> 0
      | INOP p -> 1 + dist p
      | ICOND (p1, p2) -> 1 + (max (dist p1) (dist p2))
    in
    n.dist <- d;
    d
  end

(* [final_export f c] builds the result of the oracle from [c]: a [PTree] mapping the label
   of each node that is not its own representative to the pair made of the label of its
   link and of its distance (see [dist]). Also prints statistics through [debug].
   The argument [f] is not used. *)
let final_export f c =
  (* keep the nodes that are not their own representative, and initialize all distances *)
  let collect _ n acc =
    if target c n == n then begin
      n.dist <- 0; (* force [n] to be a base case in the recursion of [dist] *)
      acc
    end else begin
      n.dist <- undef_dist; (* force [dist] to compute the actual [n.dist] *)
      n :: acc
    end
  in
  let nops = Hashtbl.fold collect c.nodes [] in
  let eliminated = List.length nops in
  let bind table n =
    let d = Z.of_uint (dist n) in
    PTree.set (lab_p n) (lab_p n.link, d) table
  in
  let res = List.fold_left bind PTree.empty nops in
  debug "* RTLTunneling.branch_target: [stats] initial number of nops = %d\n" !nopcounter;
  debug "* RTLTunneling.branch_target: [stats] final number of eliminated nops = %d\n" eliminated;
  res

(*********************************************)
(*** START: printing and debugging functions *)

(* [string_of_labeli nodes ipc] describes the node registered at integer label [ipc]:
   a target (node linked to itself) or a nop with the label of its link, in both cases
   with its distance. Returns the empty string when no node is registered at [ipc]. *)
let string_of_labeli nodes ipc =
  match Hashtbl.find nodes ipc with
  | exception Not_found -> ""
  | pc ->
     let d = dist pc in
     if pc.link == pc
     then Printf.sprintf "(Target@%d)" d
     else Printf.sprintf "(Nop %d @%d)" (lab_i pc.link) d

(* Folding step of [print_cfg]: prints instruction [i] of integer label [pc].
   [Inop] and [Icond] are printed on a line of their own (after a line break if [println]
   is set) and the result is [false]; for any other instruction, only its label is printed
   and the result is [true]. *)
let print_instr c println (pc, i) =
  let break_line () = if println then debug "\n" in
  match i with
  | Inop s ->
     break_line ();
     let info = string_of_labeli c.nodes pc in
     debug "%d:Inop %d %s\n" pc (P.to_int s) info;
     false
  | Icond (_, _, s1, s2, _) ->
     break_line ();
     let info = string_of_labeli c.nodes pc in
     debug "%d:Icond (%d,%d) %s\n" pc (P.to_int s1) (P.to_int s2) info;
     false
  | _ ->
     debug "%d " pc;
     true


(* [print_cfg f c] dumps, through [debug], the entry point of [f], its instructions by
   decreasing label (see [print_instr]) and the labels of the remaining conditions of [c]. *)
let print_cfg f c  =
  let instrs =
    List.sort
      (fun (i1, _) (i2, _) -> i2 - i1)
      (PTree.fold (fun acc pc i -> (P.to_int pc, i) :: acc) f.fn_code [])
  in
  let ep = P.to_int f.fn_entrypoint in
  let ep_info = string_of_labeli c.nodes ep in
  debug "entrypoint: %d %s\n" ep ep_info;
  let pending_newline = List.fold_left (print_instr c) false instrs in
  if pending_newline then debug "\n";
  debug "remaining cond:";
  List.iter (fun n -> debug "%d " (lab_i n)) c.rems;
  debug "\n"

(*************************************************************)
(* Adapted from the extracted code of the verifier,          *)
(* with [raise (BugOnPC (P.to_int pc))] instead of [Error.*] *)

(* [get td pc] returns the (target, distance) pair bound to [pc] in [td];
   an unbound [pc] is its own target, at distance 0. *)
let get td pc =
  match PTree.get pc td with
  | Some (t0, d) -> (t0, d)
  | None -> (pc, Z.of_uint 0)

(* [check_instr td pc i] checks the binding of [pc] in [td] against instruction [i].
   Nothing is checked when [pc] is unbound. Otherwise [i] must be an [Inop] or an [Icond]
   whose successors all have (according to [get]) the same target as [pc] and a distance
   strictly smaller than the one of [pc].
   Raises [BugOnPC (P.to_int pc)] when the check fails. *)
let check_instr td pc i =
  match PTree.get pc td with
  | None -> ()
  | Some (tpc, dpc) ->
     let consistent =
       match i with
       | Inop s ->
          let (ts, ds) = get td s in
          peq tpc ts && zlt ds dpc
       | Icond (_, _, s1, s2, _) ->
          let (ts1, ds1) = get td s1 in
          let (ts2, ds2) = get td s2 in
          peq tpc ts1 && peq tpc ts2 && zlt ds1 dpc && zlt ds2 dpc
       | _ -> false
     in
     if not consistent then raise (BugOnPC (P.to_int pc))

(* [check_code td c] runs [check_instr td] on every instruction of the code [c].
   Returns unit; raises [BugOnPC] on the first failing instruction. *)

let check_code td c =
  PTree.fold (fun () pc i -> check_instr td pc i) c ()

(*** END: copy-paste & debugging functions *******)

(* [branch_target f] is the oracle itself: it builds the simplified CFG of [f], turns into
   branches the conditions whose two successors have the same target (until no more
   progress is made), and returns the [PTree] computed by [final_export].
   When [debug_flag] is set, the result is checked with [check_code]: on failure the CFG
   is dumped with [print_cfg] and the exception of the checker is propagated. *)
let branch_target f =
  debug "* RTLTunneling.branch_target: starting on a new function\n";
  if limit_tunneling <> None then debug "* WARNING: limit_tunneling <> None\n";
  (* the statistics printed by [final_export] are per function: restart the count of nops *)
  nopcounter := 0;
  let c = { nodes = Hashtbl.create 100; rems = []; num_rems = 0; iter_num = 0 } in
  c.rems <- PTree.fold (build_simplified_cfg c) f.fn_code [];
  repeat_change_cond c;
  let res = final_export f c in
  if !debug_flag then begin
    match check_code res f.fn_code with
    | () -> if final_dump then print_cfg f c
    | exception e ->
       (* dump the CFG to help diagnosing, then propagate the very exception of the checker
          (instead of running the checker a second time only to raise it again) *)
       print_cfg f c;
       raise e
  end;
  res