aboutsummaryrefslogtreecommitdiffstats
path: root/backend/CSE3analysisaux.ml
blob: efe6b6000852bb9e14a515871ef81bdc561e5605 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
(* *************************************************************)
(*                                                             *)
(*             The Compcert verified compiler                  *)
(*                                                             *)
(*           David Monniaux     CNRS, VERIMAG                  *)
(*                                                             *)
(*  Copyright VERIMAG. All rights reserved.                    *)
(*  This file is distributed under the terms of the INRIA      *)
(*  Non-Commercial License Agreement.                          *)
(*                                                             *)
(* *************************************************************)

open CSE3analysis
open Maps
open HashedSet
open Camlcoq
open Coqlib
   
(* Flattened form of an equation or condition in which every register
   (equation left-hand side and arguments) has been converted to a plain
   [int] with [P.to_int]; used as the key of the equation table built in
   [preanalysis]. *)
type flattened_equation_or_condition =
  (* [Flat_equ (lhs, sop, args)] stands for the equation [lhs = sop(args)]. *)
  | Flat_equ of int * sym_op * int list
  (* [Flat_cond (cond, args)] stands for condition [cond] applied to [args]. *)
  | Flat_cond of Op.condition * int list;;

(* Convert an equation or condition to its flattened, int-keyed form by
   mapping every register through [P.to_int]. *)
let flatten_eq eq =
  let to_ints = List.map P.to_int in
  match eq with
  | Cond(cond, args) -> Flat_cond(cond, to_ints args)
  | Equ(lhs, sop, args) -> Flat_equ(P.to_int lhs, sop, to_ints args);;

(* In the mutable map [s] (key -> set), add [j] to the set stored under [i]. *)
let imp_add_i_j s i j =
  let row = PMap.get i !s in
  s := PMap.set i (PSet.add j row) !s;;

(* Short textual name of a memory chunk, as used by [print_eq]. *)
let string_of_chunk chunk =
  match chunk with
  | AST.Mint8signed    -> "int8signed"
  | AST.Mint8unsigned  -> "int8unsigned"
  | AST.Mint16signed   -> "int16signed"
  | AST.Mint16unsigned -> "int16unsigned"
  | AST.Mint32         -> "int32"
  | AST.Mint64         -> "int64"
  | AST.Mfloat32       -> "float32"
  | AST.Mfloat64       -> "float64"
  | AST.Many32         -> "any32"
  | AST.Many64         -> "any64";;

(* Print register number [i] on [channel] as "r<i>". *)
let print_reg channel i =
  output_string channel ("r" ^ string_of_int i);;

(* Print a flattened equation [lhs = sop(args)] on [channel], registers being
   shown as "r<n>". Loads are printed as "lhs = <chunk> @ <addressing>".
   Previously this ignored [channel] and always wrote to stdout. *)
let print_eq channel (lhs, sop, args) =
  match sop with
  | SOp op ->
     Printf.fprintf channel "%a = %a" print_reg lhs
       (PrintOp.print_operation print_reg) (op, args)
  | SLoad(chunk, addr) ->
     Printf.fprintf channel "%a = %s @ %a" print_reg lhs (string_of_chunk chunk)
       (PrintOp.print_addressing print_reg) (addr, args);;

(* Print a flattened condition on [channel] as "cond <condition>".
   Previously this ignored [channel] and always wrote to stdout. *)
let print_cond channel (cond, args) =
  Printf.fprintf channel "cond %a"
    (PrintOp.print_condition print_reg) (cond, args);;

(* Print a [PSet.t] of positives on [oc] as "{ a; b; }". *)
let pp_intset oc s =
  output_string oc "{ ";
  PSet.elements s
  |> List.iter (fun elt -> Printf.fprintf oc "%d; " (P.to_int elt));
  output_string oc "}";;

(* Print the right-hand side [sop(args)] of an equation on [oc]; a load is
   printed as "<chunk name>[<addressing>]". *)
let pp_rhs oc (sop, args) =
  match sop with
  | SLoad(chunk, addr) ->
     let chunk_name = PrintAST.name_of_chunk chunk in
     Printf.fprintf oc "%s[%a]" chunk_name
       (PrintOp.print_addressing PrintRTL.reg) (addr, args)
  | SOp op ->
     PrintOp.print_operation PrintRTL.reg oc (op, args);;

(* Print an (unflattened) equation as "x<lhs> = <rhs>" or a condition as
   "cond <condition>" on [oc]. *)
let pp_eq oc eq_cond =
  match eq_cond with
  | Cond(cond, args) ->
     Printf.fprintf oc "cond %a" (PrintOp.print_condition PrintRTL.reg) (cond, args)
  | Equ(lhs, sop, args) ->
     Printf.fprintf oc "x%d = %a" (P.to_int lhs) pp_rhs (sop, args);;

(* Print a positive as a decimal integer on [oc]. *)
let pp_P oc x = output_string oc (string_of_int (P.to_int x))
              
(* Print an option on [oc]: "none" for [None], otherwise delegate to [pp]. *)
let pp_option pp oc opt =
  match opt with
  | Some x -> pp oc x
  | None -> output_string oc "none";;

(* True exactly for a self-move [r = Omove(r)]. *)
let is_trivial eq =
  match eq with
  | Equ(lhs, SOp Op.Omove, [src]) -> lhs = src
  | _ -> false;;

(* Print the items of a list on [chan] with [pp_item], writing [separator]
   between consecutive items (not before the first nor after the last). *)
let pp_list separator pp_item chan items =
  List.iteri
    (fun idx item ->
      if idx > 0 then output_string chan separator;
      pp_item chan item)
    items;;

(* Print the elements of a [PSet.t], in [PSet.elements] order, separated by
   [separator]. *)
let pp_set separator pp_item chan s =
  PSet.elements s |> pp_list separator pp_item chan;;

(* Print the catalogued equation/condition with id [x] on [chan], or "???"
   when [x] is not in [hints.hint_eq_catalog]. *)
let pp_equation hints chan x =
  match PTree.get x hints.hint_eq_catalog with
  | Some (Equ(lhs, sop, args)) ->
     print_eq chan (P.to_int lhs, sop, List.map P.to_int args)
  | Some (Cond(cond, args)) ->
     print_cond chan (cond, List.map P.to_int args)
  | None -> output_string chan "???";;

(* Print a relation (set of equation ids) as its equations separated by "; ". *)
let pp_relation hints chan rel =
  let pp_item = pp_equation hints in
  pp_set "; " pp_item chan rel;;

(* Print a possibly-bottom relation: "bot" for [None]. *)
let pp_relation_b hints chan rb =
  match rb with
  | Some rel -> pp_relation hints chan rel
  | None -> output_string chan "bot";;

(* Dump the invariant of every node of [f], from [RTL.max_pc_function f]
   down to 1, one "pc: relation" entry per node followed by a blank line. *)
let pp_results f (invariants : RB.t PMap.t) hints chan =
  let rec print_from pc =
    if pc >= 1 then begin
      Printf.fprintf chan "%d: %a\n\n" pc
        (pp_relation_b hints) (PMap.get (P.of_int pc) invariants);
      print_from (pc - 1)
    end in
  print_from (P.to_int (RTL.max_pc_function f))

(* Sets of ints, used as worklists of node numbers. The comparison is an
   explicit three-way test rather than [( - )]: subtraction overflows when the
   operands are far apart (e.g. [max_int] and [min_int]) and then yields an
   inconsistent order. *)
module IntSet = Set.Make(struct
  type t = int
  let compare (x : int) (y : int) =
    if x < y then -1 else if x > y then 1 else 0
end);;

(* Join ([RB.lub]) every element of the list into [prev], left to right. *)
let union_list prev rbs = List.fold_left RB.lub prev rbs;;

(* Meet of two possibly-bottom relations: [None] if either is [None],
   otherwise [RELATION.glb] of the two relations. *)
let rb_glb (x : RB.t) (y : RB.t) : RB.t =
  match x with
  | None -> None
  | Some x' ->
     (match y with
      | None -> None
      | Some y' -> Some (RELATION.glb x' y'));;

(* Forward worklist fixpoint. The entry node starts at [Some RELATION.top],
   every other node at [RB.bot]. The highest-numbered pending node is
   processed first; each (successor, contribution) pair produced by [tfr] is
   joined into the successor with [RB.lub], and the successor is re-queued
   when its value changed. [nodes] is not used. *)
let compute_invariants
      (nodes : RTL.node list)
      (entrypoint : RTL.node)
      (tfr : RTL.node -> RB.t -> (RTL.node * RB.t) list) =
  let worklist = ref IntSet.empty in
  let table =
    ref (PMap.set entrypoint (Some RELATION.top) (PMap.init RB.bot)) in
  let schedule (pc : RTL.node) =
    worklist := IntSet.add (P.to_int pc) !worklist in
  let propagate (succ_pc, contrib) =
    let old_value = PMap.get succ_pc !table in
    let new_value = RB.lub old_value contrib in
    if not (RB.beq old_value new_value) then begin
      table := PMap.set succ_pc new_value !table;
      schedule succ_pc
    end in
  let visit (pc : RTL.node) =
    if !Clflags.option_debug_compcert > 9 then
      Printf.printf "UP updating node %d\n" (P.to_int pc);
    let edges = tfr pc (PMap.get pc !table) in
    List.iter propagate edges in
  let rec drain () =
    if not (IntSet.is_empty !worklist) then begin
      let pc = IntSet.max_elt !worklist in
      worklist := IntSet.remove pc !worklist;
      visit (P.of_int pc);
      drain ()
    end in
  schedule entrypoint;
  drain ();
  !table;;

(* Downward refinement of [invariants0]. Every non-entry node is recomputed
   as the join, over its predecessors, of the meet of its current value with
   what that predecessor sends to it through [tfr]. A node whose value
   changes is updated, and its successors are re-queued. The highest-numbered
   pending node is processed first; all [nodes] are initially pending.

   Fixes relative to the previous version:
   - the refined value is now stored (it used to be computed and discarded,
     so the function always returned [invariants0] and could keep re-queuing
     the same nodes);
   - every edge from a predecessor to the node is taken into account (a
     predecessor may reach it along several edges), instead of only the first
     one found by [List.assoc], which also raised [Not_found] when no edge
     matched. *)
let refine_invariants
      (nodes : RTL.node list)
      (entrypoint : RTL.node)
      (successors : RTL.node -> RTL.node list)
      (predecessors : RTL.node -> RTL.node list)
      (tfr : RTL.node -> RB.t -> (RTL.node * RB.t) list)
      (invariants0 : RB.t PMap.t) =
  let todo = ref IntSet.empty
  and invariants = ref invariants0 in
  let add_todo (pc : RTL.node) =
    todo := IntSet.add (P.to_int pc) !todo in
  (* Join of all contributions [pred_pc] sends to [pc]; [RB.bot] if none. *)
  let contribution pred_pc pc =
    List.fold_left
      (fun acc (succ_pc, contrib) ->
        if peq succ_pc pc then RB.lub acc contrib else acc)
      RB.bot
      (tfr pred_pc (PMap.get pred_pc !invariants)) in
  let update_node (pc : RTL.node) =
    (if !Clflags.option_debug_compcert > 9
     then Printf.printf "DOWN updating node %d\n" (P.to_int pc));
    if not (peq pc entrypoint)
    then
      let cur = PMap.get pc !invariants in
      let nxt = union_list RB.bot
                  (List.map
                     (fun pred_pc -> rb_glb cur (contribution pred_pc pc))
                     (predecessors pc)) in
      if not (RB.beq cur nxt)
      then
        begin
          (if !Clflags.option_debug_compcert > 4
           then Printf.printf "refining CSE3 node %d\n" (P.to_int pc));
          invariants := PMap.set pc nxt !invariants;
          List.iter add_todo (successors pc)
        end in
  List.iter add_todo nodes;
  while not (IntSet.is_empty !todo) do
    let nxt = IntSet.max_elt !todo in
    todo := IntSet.remove nxt !todo;
    update_node (P.of_int nxt)
  done;
  !invariants;;

(* Value bound to [x] in [ptree], or [default] when absent. *)
let get_default default x ptree =
  match PTree.get x ptree with
  | Some y -> y
  | None -> default;;

(* Run the forward fixpoint [compute_invariants] over all nodes of [f],
   using [apply_instr' ctx tenv] on its code as the transfer function. *)
let initial_analysis ctx tenv (f : RTL.coq_function) =
  let code = f.RTL.fn_code in
  let tfr = apply_instr' ctx tenv code in
  let all_nodes = List.map fst (PTree.elements code) in
  compute_invariants all_nodes f.RTL.fn_entrypoint tfr;;

(* Run [refine_invariants] on [invariants0] for [f]. Successor and
   predecessor maps are built once; missing entries default to []. *)
let refine_analysis ctx tenv
      (f : RTL.coq_function) (invariants0 : RB.t PMap.t) =
  let code = f.RTL.fn_code in
  let successors_of =
    let table = RTL.successors_map f in
    fun pc -> get_default [] pc table in
  let predecessors_of =
    let table = Kildall.make_predecessors code RTL.successors_instr in
    fun pc -> get_default [] pc table in
  let tfr = apply_instr' ctx tenv code in
  let all_nodes = List.map fst (PTree.elements code) in
  refine_invariants all_nodes f.RTL.fn_entrypoint
    successors_of predecessors_of tfr invariants0;;

(* Add [item] to the [PSet.t] stored under [key] in the hash table [table],
   starting from [PSet.empty] if [key] is absent. Uses [Hashtbl.replace]
   rather than [Hashtbl.add]: [add] would stack a new binding on every call,
   leaving stale sets in the table. *)
let add_to_set_in_table table key item =
  let previous =
    match Hashtbl.find_opt table key with
    | None -> PSet.empty
    | Some s -> s in
  Hashtbl.replace table key (PSet.add item previous);;
  
(* Compute the CSE3 invariants of [f] and the hints used by later phases.
   Equations are catalogued lazily during the analysis by
   [mutating_eq_find_oracle], which assigns ids 1, 2, ... and maintains the
   indices exposed through [ctx]. If [Clflags.option_fcse3_refine] is set,
   the initial invariants are refined with [refine_analysis]. Returns
   [(invariants, hints)]. *)
let preanalysis (tenv : typing_env) (f : RTL.coq_function) =
  (* Shared mutable state: last equation id handed out, catalog id ->
     equation, flattened equation -> id, flattened right-hand side -> set of
     ids, and the per-register / memory / store kill sets and move index. *)
  let cur_eq_id = ref 0
  and cur_catalog = ref PTree.empty
  and eq_table = Hashtbl.create 100
  and rhs_table = Hashtbl.create 100
  and cur_kill_reg = ref (PMap.init PSet.empty)
  and cur_kill_mem = ref PSet.empty
  and cur_kill_store = ref PSet.empty
  and cur_moves = ref (PMap.init PSet.empty) in
  (* Read-only lookup of an already catalogued equation (asserts [eq] is not
     a trivial self-move); handed out in the hints. *)
  let eq_find_oracle node eq =
    assert (not (is_trivial eq));
    let o = Hashtbl.find_opt eq_table (flatten_eq eq) in
    (* FIXME (if o = None then failwith "eq_find_oracle"); *)
    (if !Clflags.option_debug_compcert > 5
     then Printf.printf "@%d: eq_find %a -> %a\n" (P.to_int node)
            pp_eq eq (pp_option pp_P) o);
    o
  (* Ids of all catalogued equations whose right-hand side is [sop(args)]. *)
  and rhs_find_oracle node sop args =
    let o =
      match Hashtbl.find_opt rhs_table (sop, List.map P.to_int args) with
      | None -> PSet.empty
      | Some s -> s in
    (if !Clflags.option_debug_compcert > 5
     then Printf.printf "@%d: rhs_find %a = %a\n"
            (P.to_int node) pp_rhs (sop, args) pp_intset o);
    o in
  (* Lookup that registers [eq] on first sight: gives it the next id, records
     it in the catalog and (for equations) the right-hand-side table, indexes
     it under every register it mentions, under its destination when it is a
     single-argument move, and in the memory / store kill sets when it
     depends on them. Always returns [Some id]. *)
  let mutating_eq_find_oracle node eq : P.t option =
    let flat_eq = flatten_eq eq in
    let o =
      match Hashtbl.find_opt eq_table flat_eq with
      | Some x ->
         Some x
      | None ->
         begin
           (* TODO print_eq stderr flat_eq; *)
           incr cur_eq_id;
           let coq_id = P.of_int !cur_eq_id in
           Hashtbl.add eq_table flat_eq coq_id;
           cur_catalog := PTree.set coq_id eq !cur_catalog;
           (match flat_eq with
            | Flat_equ(_, flat_eq_op, flat_eq_args) ->
               add_to_set_in_table rhs_table
                 (flat_eq_op, flat_eq_args) coq_id
            | Flat_cond(_, _) -> ());
           (match eq with
            | Equ(lhs, sop, args) ->
               List.iter
                 (fun reg -> imp_add_i_j cur_kill_reg reg coq_id)
                 (lhs :: args);
               (match sop, args with
                | (SOp Op.Omove), [_] -> imp_add_i_j cur_moves lhs coq_id
                | _, _ -> ())
            | Cond(_, args) ->
               List.iter
                 (fun reg -> imp_add_i_j cur_kill_reg reg coq_id) args);
           (if eq_cond_depends_on_mem eq
            then cur_kill_mem := PSet.add coq_id !cur_kill_mem);
           (if eq_cond_depends_on_store eq
            then cur_kill_store := PSet.add coq_id !cur_kill_store);
           Some coq_id
         end
    in
    (if !Clflags.option_debug_compcert > 5
     then Printf.printf "@%d: mutating_eq_find %a -> %a\n" (P.to_int node)
            pp_eq eq (pp_option pp_P) o);
    o
  in
  let ctx = { eq_catalog     = (fun eq_id -> PTree.get eq_id !cur_catalog);
              eq_find_oracle = mutating_eq_find_oracle;
              eq_rhs_oracle  = rhs_find_oracle ;
              eq_kill_reg    = (fun reg -> PMap.get reg !cur_kill_reg);
              eq_kill_mem    = (fun () -> !cur_kill_mem);
              eq_kill_store  = (fun () -> !cur_kill_store);
              eq_moves       = (fun reg -> PMap.get reg !cur_moves)
            } in
  let invariants = initial_analysis ctx tenv f in
  (* Refinement runs through [ctx] and may catalogue new equations, so it
     must finish before [cur_catalog] is snapshotted into the hints. These
     two bindings are sequenced explicitly because the evaluation order of a
     [let ... and ...] is unspecified. *)
  let invariants' =
    if !Clflags.option_fcse3_refine
    then refine_analysis ctx tenv f invariants
    else invariants in
  let hints = { hint_eq_catalog    = !cur_catalog;
                hint_eq_find_oracle= eq_find_oracle;
                hint_eq_rhs_oracle = rhs_find_oracle } in
  (if !Clflags.option_debug_compcert > 1
   then pp_results f invariants' hints stdout);
  invariants', hints
;;