aboutsummaryrefslogtreecommitdiffstats
path: root/backend/Linearizeaux.ml
blob: 402e376d7cf86c7114bd6802582c3b817bc3e2f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
(* *********************************************************************)
(*                                                                     *)
(*              The Compcert verified compiler                         *)
(*                                                                     *)
(*          Xavier Leroy, INRIA Paris-Rocquencourt                     *)
(*                                                                     *)
(*  Copyright Institut National de Recherche en Informatique et en     *)
(*  Automatique.  All rights reserved.  This file is distributed       *)
(*  under the terms of the INRIA Non-Commercial License Agreement.     *)
(*                                                                     *)
(* *********************************************************************)

open LTL
open Maps

let debug_flag = ref false

(* [debug fmt args...] prints on stderr when [debug_flag] is set;
   otherwise the arguments are consumed and nothing is printed. *)
let debug fmt =
  match !debug_flag with
  | true -> Printf.eprintf fmt
  | false -> Printf.ifprintf stderr fmt

(* Trivial enumeration, in decreasing order of PC *)

(***
let enumerate_aux f reach =
  positive_rec
    Coq_nil
    (fun pc nodes ->
      if PMap.get pc reach
      then Coq_cons (pc, nodes)
      else nodes)
    f.fn_nextpc
***)

(* More clever enumeration that flattens basic blocks *)

open Camlcoq

(* Sets of PCs converted to native ints. *)
module IntSet = Set.Make (struct
  type t = int
  let compare (a : int) (b : int) = compare a b
end)

(* Determine join points: reachable nodes that have > 1 predecessor *)

let join_points f =
  (* PCs reached at least once / at least twice during the DFS. *)
  let seen = ref IntSet.empty in
  let seen_twice = ref IntSet.empty in
  let rec visit pc =
    let key = P.to_int pc in
    if not (IntSet.mem key !seen) then begin
      seen := IntSet.add key !seen;
      match PTree.get pc f.fn_code with
      | Some b -> List.iter visit (successors_block b)
      | None -> ()
    end else if not (IntSet.mem key !seen_twice) then
      seen_twice := IntSet.add key !seen_twice
  in
  visit f.fn_entrypoint;
  !seen_twice

(* Cut into reachable basic blocks, annotated with the min value of the PC *)

(* [basic_blocks f joins] partitions the code reachable from
   [f.fn_entrypoint] into blocks.  Each block is returned as a pair
   [(minpc, nodes)]: [nodes] lists the block's PCs in traversal order and
   [minpc] is the smallest [P.to_int] among them.  A block keeps growing
   through [Lbranch] targets that are not in [joins]; it ends at a tail
   call, a return, a conditional, a jump table, or just before a join
   point.  The result lists the most recently completed block first. *)
let basic_blocks f joins =
  let blocks = ref [] in
  (* PCs (as ints) at which a block has already been started. *)
  let visited = ref IntSet.empty in
  (* start_block:
       pc is the function entry point
          or a join point
          or the successor of a conditional test
          or a jump-table target.
     Does nothing if a block was already started at pc. *)
  let rec start_block pc =
    let npc = P.to_int pc in
    if not (IntSet.mem npc !visited) then begin
      visited := IntSet.add npc !visited;
      in_block [] max_int pc
    end
  (* in_block: add pc to block and check successors.
     [blk] holds the block's PCs in reverse order; [minpc] is the running
     minimum.  Only the first control instruction found in the node's
     instruction list decides how the traversal continues; other
     instructions are skipped. *)
  and in_block blk minpc pc =
    let npc = P.to_int pc in
    let blk = pc :: blk in
    let minpc = min npc minpc in
    match PTree.get pc f.fn_code with
    | None -> assert false
    | Some b ->
       let rec do_instr_list = function
       (* a node without a control instruction is not expected here *)
       | [] -> assert false
       | Lbranch s :: _ -> next_in_block blk minpc s
       | Ltailcall (sig0, ros) :: _ -> end_block blk minpc
       | Lcond (cond, args, ifso, ifnot, _) :: _ ->
             end_block blk minpc; start_block ifso; start_block ifnot
       | Ljumptable(arg, tbl) :: _ ->
             end_block blk minpc; List.iter start_block tbl
       | Lreturn :: _ -> end_block blk minpc
       | instr :: b' -> do_instr_list b' in
       do_instr_list b
  (* next_in_block: check if join point and either extend block
     or start block *)
  and next_in_block blk minpc pc =
    let npc = P.to_int pc in
    if IntSet.mem npc joins
    then (end_block blk minpc; start_block pc)
    else in_block blk minpc pc
  (* end_block: record block that we just discovered
     (PCs put back in traversal order) *)
  and end_block blk minpc =
    blocks := (minpc, List.rev blk) :: !blocks
  in
    start_block f.fn_entrypoint; !blocks

(* Flatten basic blocks in decreasing order of minpc *)

(* Concatenate the node lists of [blks], taking blocks by decreasing minpc.
   Blocks with equal minpc keep their relative order (stable sort). *)
let flatten_blocks blks =
  let by_decreasing_minpc (m1, _) (m2, _) = compare m2 m1 in
  blks
  |> List.stable_sort by_decreasing_minpc
  |> List.map snd
  |> List.concat

(* Build the enumeration *)

(* Basic-block enumeration: blocks laid out by decreasing minimum PC.
   The reachability argument is not used. *)
let enumerate_aux_flat f _reach =
  join_points f |> basic_blocks f |> flatten_blocks

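(**
 * Alternate enumeration based on superblocks, i.e. the traces identified
 * by Duplicate.v.
 *
 * This slightly alters the above heuristic: a conditional whose successor is
 * predicted does not end the current block; the predicted successor extends
 * it instead. Every superblock is therefore contiguous in memory, while the
 * original ordering heuristic still applies.
 *)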

(* [super_blocks f joins] works like [basic_blocks], except for [Lcond]:
   when its last field [pred] is [Some true] (resp. [Some false]), the
   [ifso] (resp. [ifnot]) successor continues the current block via
   [next_in_block], and only the other successor starts a new block.  With
   [pred = None] the conditional ends the block as in [basic_blocks].
   [pred] is presumably the branch prediction recorded by Duplicate.v --
   NOTE(review): confirm.  Returns [(minpc, nodes)] pairs, most recently
   completed block first. *)
let super_blocks f joins =
  let blocks = ref [] in
  (* PCs (as ints) at which a block has already been started.  Nodes
     appended through in_block are not recorded here; this relies on them
     not being join points. *)
  let visited = ref IntSet.empty in
  (* start_block:
       pc is the function entry point
          or a join point
          or the non-predicted successor of a conditional test
          or a jump-table target *)
  let rec start_block pc =
    let npc = P.to_int pc in
    if not (IntSet.mem npc !visited) then begin
      visited := IntSet.add npc !visited;
      in_block [] max_int pc
    end
  (* in_block: add pc to block and check successors
     ([blk] is in reverse order, [minpc] is the running minimum) *)
  and in_block blk minpc pc =
    let npc = P.to_int pc in
    let blk = pc :: blk in
    let minpc = min npc minpc in
    match PTree.get pc f.fn_code with
    | None -> assert false
    | Some b ->
       let rec do_instr_list = function
       | [] -> assert false
       | Lbranch s :: _ -> next_in_block blk minpc s
       | Ltailcall (sig0, ros) :: _ -> end_block blk minpc
       | Lcond (cond, args, ifso, ifnot, pred) :: _ -> begin
            (* predicted successor stays in the current block *)
            match pred with
            | None -> (end_block blk minpc; start_block ifso; start_block ifnot)
            | Some true -> (next_in_block blk minpc ifso; start_block ifnot)
            | Some false -> (next_in_block blk minpc ifnot; start_block ifso)
          end
       | Ljumptable(arg, tbl) :: _ ->
             end_block blk minpc; List.iter start_block tbl
       | Lreturn :: _ -> end_block blk minpc
       | instr :: b' -> do_instr_list b' in
       do_instr_list b
  (* next_in_block: check if join point and either extend block
     or start block *)
  and next_in_block blk minpc pc =
    let npc = P.to_int pc in
    if IntSet.mem npc joins
    then (end_block blk minpc; start_block pc)
    else in_block blk minpc pc
  (* end_block: record block that we just discovered *)
  and end_block blk minpc =
    blocks := (minpc, List.rev blk) :: !blocks
  in
    start_block f.fn_entrypoint; !blocks

(* Build the enumeration *)

(* Superblock enumeration: superblocks laid out by decreasing minimum PC.
   The reachability argument is not used. *)
let enumerate_aux_sb f _reach =
  join_points f |> super_blocks f |> flatten_blocks

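(**
 * Alternate enumeration based on traces, selected when
 * [Clflags.option_ftracelinearize] is set.
 *
 * The code is cut into "forward sequences". Starting from a node, a sequence
 * follows [Lbranch] targets and the predicted successor of [Lcond]
 * (presumably the prediction recorded by Duplicate.v). It stops at a join
 * point, a return, a tail call, a jump table or an unpredicted conditional.
 * The sequences are then emitted contiguously, ordered by decreasing PC of
 * their first node.
 *)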

(* Unwrap an option, failing with [Failure "Did not get some"] on [None]. *)
let get_some opt =
  match opt with
  | Some thing -> thing
  | None -> failwith "Did not get some"

exception EmptyList

(* [last_element l] returns the final element of [l].
   Raises [EmptyList] when [l] is empty. *)
let rec last_element l =
  match l with
  | [] -> raise EmptyList
  | [only] -> only
  | _ :: rest -> last_element rest

(* Debug helper: when [debug_flag] is set, print the PCs of [l] as
   "[pc1, pc2, ]".  Output goes to stderr, the same stream as [debug],
   so that it stays next to the labels printed with [debug]. *)
let print_plist l =
  if !debug_flag then begin
    Printf.eprintf "[";
    List.iter (fun n -> Printf.eprintf "%d, " (P.to_int n)) l;
    Printf.eprintf "]"
  end

(* adapted from the above join_points function, but with PTree *)
(* Returns a PTree over [code] mapping each node to [true] iff it is
   reached more than once by a DFS from [entry].  Fails (via [get_some])
   if a traversed successor is not a node of [code]. *)
let get_join_points code entry =
  let unmarked () = PTree.map (fun _ _ -> false) code in
  let reached = ref (unmarked ()) in
  let reached_twice = ref (unmarked ()) in
  let is_marked tbl pc = get_some (PTree.get pc !tbl) in
  let mark tbl pc = tbl := PTree.set pc true !tbl in
  let rec traverse pc =
    if is_marked reached pc then begin
      if not (is_marked reached_twice pc) then mark reached_twice pc
    end else begin
      mark reached pc;
      List.iter traverse (successors_block (get_some (PTree.get pc code)))
    end
  in
  traverse entry;
  !reached_twice

(* [forward_sequences code entry] cuts the code reachable from [entry] into
   sequences of nodes (a list of node lists).  Each node is placed in at
   most one sequence, tracked by [visited].  A sequence starting at a node
   follows [Lbranch] targets and the predicted successor of [Lcond] (its
   last field: [Some true] -> ifso, [Some false] -> ifnot), unless that
   successor is a join point.  It stops at tail calls, returns, jump tables
   and unpredicted conditionals.  Nodes not followed are queued and start
   later sequences.  Starting a sequence at an already-visited node yields
   an empty sequence [[]], which is kept in the result. *)
let forward_sequences code entry =
  let visited = ref (PTree.map (fun n i -> false) code) in
  let join_points = get_join_points code entry in
  (* returns the list of traversed nodes, and a list of nodes to start traversing next *)
  let rec traverse_fallthrough code node =
    (* debug "Traversing %d..\n" (P.to_int node); *)
    if not (get_some @@ PTree.get node !visited) then begin
      visited := PTree.set node true !visited;
      match PTree.get node code with
      | None -> failwith "No such node"
      | Some bb ->
          (* only the last instruction of the node matters; it must be a
             control instruction (raises EmptyList if bb is empty) *)
          let ln, rem = match (last_element bb) with
          | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
          | Lbuiltin _ -> assert false
          | Ltailcall _ | Lreturn -> begin (* debug "STOP tailcall/return\n"; *) ([], []) end
          | Lbranch n ->
              if get_some @@ PTree.get n join_points then ([], [n])
              else let ln, rem = traverse_fallthrough code n in (ln, rem)
          | Lcond (_, _, ifso, ifnot, info) -> (match info with
            | None -> begin (* debug "STOP Lcond None\n"; *) ([], [ifso; ifnot]) end
            | Some false ->
                if get_some @@ PTree.get ifnot join_points then ([], [ifso; ifnot])
                else let ln, rem = traverse_fallthrough code ifnot in (ln, [ifso] @ rem)
            | Some true ->
                if get_some @@ PTree.get ifso join_points then ([], [ifso; ifnot])
                else let ln, rem = traverse_fallthrough code ifso in (ln, [ifnot] @ rem)
            )
          | Ljumptable(_, ln) -> begin (* debug "STOP Ljumptable\n"; *) ([], ln) end
          in ([node] @ ln, rem)
      end
    else ([], [])
  (* f: for each start node in order, emit its sequence, then (depth-first)
     the sequences started from the nodes it left over, then the rest *)
  in let rec f code = function
  | [] -> []
  | node :: ln ->
      let fs, rem_from_node = traverse_fallthrough code node
      in [fs] @ ((f code rem_from_node) @ (f code ln))
  in (f code [entry])

type pos = BinNums.positive

(* Pairs of positives, ordered lexicographically with the polymorphic
   [compare] on each component. *)
module PP = struct
  type t = pos * pos
  let compare (ax, ay) (bx, by) =
    match compare ax bx with
    | 0 -> compare ay by
    | c -> c
end

module PPMap = Map.Make(PP)

type vstate = Unvisited | Processed | Visited

(* [true] iff [pp] is bound to [true] in [ppmap]; unbound keys give
   [false]. *)
let ppmap_is_true pp ppmap =
  try PPMap.find pp ppmap with Not_found -> false

(* Native integers with their usual ordering, used to build [ISet]. *)
module Int = struct
  type t = int
  let compare (x : int) (y : int) = compare x y
end

module ISet = Set.Make(Int)

(* Debug helper: when [debug_flag] is set, print the elements of [s] as
   "{e1, e2, }" on stderr, the same stream as [debug]. *)
let print_iset s =
  if !debug_flag then begin
    Printf.eprintf "{";
    ISet.iter (fun e -> Printf.eprintf "%d, " e) s;
    Printf.eprintf "}"
  end

(* Debug helper: when [debug_flag] is set, print the PCs of sequence [s]
   as "[pc1, pc2, ]" followed by a newline.  Uses stderr, the same stream
   as [debug], so the output follows the label printed by [debug]. *)
let print_sequence s =
  if !debug_flag then begin
    Printf.eprintf "[";
    List.iter (fun n -> Printf.eprintf "%d, " (P.to_int n)) s;
    Printf.eprintf "]\n"
  end

(* Debug helper: print a list of sequences, one per line, between
   brackets, on stderr (only when [debug_flag] is set). *)
let print_ssequence ofs =
  if !debug_flag then begin
    Printf.eprintf "[";
    List.iter print_sequence ofs;
    Printf.eprintf "]\n"
  end

(* Smallest PC (as an int) among the nodes of [l], or [None] if [l] is
   empty. *)
let minpc_of l =
  List.fold_left
    (fun acc pc ->
      let score = P.to_int pc in
      match acc with
      | None -> Some score
      | Some best -> Some (if score < best then score else best))
    None l

(* [order_sequences code entry fs] returns the sequences of [fs] in
   emission order.  The order is greedy: at each step, among the sequences
   not placed yet, the one whose first node has the largest PC is placed
   next.  Empty sequences are picked only once every remaining sequence is
   empty.  [code] and [entry] are currently unused. *)
let order_sequences code entry fs =
  let fs_a = Array.of_list fs in
  let nb_seqs = Array.length fs_a in
  let fs_evaluated = Array.make nb_seqs false in
  let ordered_fs = ref [] in
  let evaluate s_id =
    assert (not fs_evaluated.(s_id));
    ordered_fs := fs_a.(s_id) :: !ordered_fs;
    fs_evaluated.(s_id) <- true
  in
  (* Score of a sequence: PC of its first node, or None if it is empty. *)
  let score_of id =
    match fs_a.(id) with
    | [] -> None
    | n :: _ -> Some (P.to_int n)
  in
  (* Best candidate = highest score.  On equal scores the earlier
     candidate in the list is kept. *)
  let choose_best_of candidates =
    let current_best_id = ref None in
    let current_best_score = ref None in
    List.iter (fun id ->
      match !current_best_id, score_of id with
      | None, score ->
          current_best_id := Some id;
          current_best_score := score
      | Some _, None -> ()
      | Some _, Some nscore ->
          (match !current_best_score with
           | None ->
               current_best_id := Some id;
               current_best_score := Some nscore
           | Some bs ->
               if nscore > bs then begin
                 current_best_id := Some id;
                 current_best_score := Some nscore
               end)
    ) candidates;
    !current_best_id
  in
  (* Candidates are the not-yet-placed sequences, by decreasing index.
     (The original code rescanned the array a second time when this list
     was empty; that rescan could never find anything and is dropped.) *)
  let select_next () =
    let candidates = ref [] in
    Array.iteri (fun i evaluated ->
      if not evaluated then candidates := i :: !candidates
    ) fs_evaluated;
    get_some (choose_best_of !candidates)
  in
  debug "-------------------------------\n";
  debug "forward sequences identified: "; print_ssequence fs;
  (* Each iteration places exactly one sequence. *)
  for _i = 1 to nb_seqs do
    evaluate (select_next ())
  done;
  let result = List.rev !ordered_fs in
  debug "forward sequences ordered: "; print_ssequence result;
  result

(* Trace-based enumeration: forward sequences, ordered and then
   concatenated.  The reachability argument is not used. *)
let enumerate_aux_trace f _reach =
  let sequences = forward_sequences f.fn_code f.fn_entrypoint in
  List.flatten (order_sequences f.fn_code f.fn_entrypoint sequences)

(* Node enumeration used for linearization: trace-based when
   [Clflags.option_ftracelinearize] is set, basic-block based otherwise. *)
let enumerate_aux f reach =
  let enumerate =
    if !Clflags.option_ftracelinearize then enumerate_aux_trace
    else enumerate_aux_flat
  in
  enumerate f reach