aboutsummaryrefslogtreecommitdiffstats
path: root/backend/Linearizeaux.ml
blob: 7aed5936728161f8de1015b808db74482d4d5540 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438

(*                                                                     *)
(*              The Compcert verified compiler                         *)
(*                                                                     *)
(*          Xavier Leroy, INRIA Paris-Rocquencourt                     *)
(*                                                                     *)
(*  Copyright Institut National de Recherche en Informatique et en     *)
(*  Automatique.  All rights reserved.  This file is distributed       *)
(*  under the terms of the INRIA Non-Commercial License Agreement.     *)
(*                                                                     *)
(* *********************************************************************)

open LTL
open Maps

(* Trivial enumeration, in decreasing order of PC *)

(***
let enumerate_aux f reach =
  positive_rec
    Coq_nil
    (fun pc nodes ->
      if PMap.get pc reach
      then Coq_cons (pc, nodes)
      else nodes)
    f.fn_nextpc
***)

(* More clever enumeration that flattens basic blocks *)

open Camlcoq

(* Sets of plain integers; below they hold node numbers obtained with
   [P.to_int]. *)
module IntSet = Set.Make (struct
  type t = int
  let compare (a : int) (b : int) = compare a b
end)

(* Determine join points: reachable nodes that have > 1 predecessor.
   Depth-first walk from [f.fn_entrypoint] over [successors_block].  A node
   met for the first time is recorded and its successors are explored; a
   node met again is added to the result.  The entry point thus becomes a
   join point as soon as some edge leads back to it.  Nodes are identified
   by their [P.to_int] value. *)

let join_points f =
  let seen = ref IntSet.empty in
  let seen_twice = ref IntSet.empty in
  let rec visit pc =
    let key = P.to_int pc in
    if not (IntSet.mem key !seen) then begin
      seen := IntSet.add key !seen;
      match PTree.get pc f.fn_code with
      | Some blk -> visit_list (successors_block blk)
      | None -> ()
    end
    else
      (* Adding an element that is already present leaves the set as is. *)
      seen_twice := IntSet.add key !seen_twice
  and visit_list = function
    | [] -> ()
    | [pc] -> visit pc (* tail call: single-successor chains need no extra stack *)
    | pc :: rest ->
        visit pc;
        visit_list rest
  in
  visit f.fn_entrypoint;
  !seen_twice

(* Cut into reachable basic blocks, annotated with the min value of the PC *)
(* [basic_blocks f joins] walks the code of [f] from its entry point and
   returns a list of pairs [(minpc, nodes)]: [nodes] is one block, in the
   order its nodes were traversed, and [minpc] is the smallest [P.to_int]
   value among them.  [joins] is a set of node numbers (such as the result
   of [join_points f]) at which a new block is forced to start.  Blocks are
   prepended as they are completed, so the list is in reverse completion
   order. *)

let basic_blocks f joins =
  (* Completed blocks, most recently completed first. *)
  let blocks = ref [] in
  (* Node numbers at which a block has already been started. *)
  let visited = ref IntSet.empty in
  (* start_block:
       pc is the function entry point
          or a join point
          or the successor of a conditional test *)
  (* Each block start is processed at most once; [max_int] is the neutral
     initial value for the running minimum. *)
  let rec start_block pc =
    let npc = P.to_int pc in
    if not (IntSet.mem npc !visited) then begin
      visited := IntSet.add npc !visited;
      in_block [] max_int pc
    end
  (* in_block: add pc to block and check successors *)
  (* [blk] holds the block's nodes in reverse order, [minpc] the smallest
     node number seen so far in this block.
     NOTE(review): only block starts are recorded in [visited]; a node
     entered through [next_in_block] is not.  This is safe only if such
     (non-join) nodes are reached once, presumably guaranteed when [joins]
     comes from [join_points] -- confirm for other callers. *)
  and in_block blk minpc pc =
    let npc = P.to_int pc in
    let blk = pc :: blk in
    let minpc = min npc minpc in
    match PTree.get pc f.fn_code with
    | None -> assert false
    | Some b ->
       (* Scan the node's instruction list up to the first control-transfer
          instruction; it decides how the block continues.  The
          [assert false] cases assume every node reached is in [fn_code]
          and that its list contains such an instruction. *)
       let rec do_instr_list = function
       | [] -> assert false
       | Lbranch s :: _ -> next_in_block blk minpc s
       | Ltailcall (sig0, ros) :: _ -> end_block blk minpc
       | Lcond (cond, args, ifso, ifnot) :: _ ->
             end_block blk minpc; start_block ifso; start_block ifnot
       | Ljumptable(arg, tbl) :: _ ->
             end_block blk minpc; List.iter start_block tbl
       | Lreturn :: _ -> end_block blk minpc
       | instr :: b' -> do_instr_list b' in
       do_instr_list b
  (* next_in_block: check if join point and either extend block
     or start block *)
  and next_in_block blk minpc pc =
    let npc = P.to_int pc in
    if IntSet.mem npc joins
    then (end_block blk minpc; start_block pc)
    else in_block blk minpc pc
  (* end_block: record block that we just discovered *)
  (* [blk] is reversed back into traversal order before being stored. *)
  and end_block blk minpc =
    blocks := (minpc, List.rev blk) :: !blocks
  in
    start_block f.fn_entrypoint; !blocks

(* Flatten basic blocks in decreasing order of minpc: [blks] is a list of
   [(minpc, nodes)] pairs; the node lists are concatenated starting with
   the pair whose [minpc] is largest. *)

let flatten_blocks blks =
  let by_decreasing_minpc (a, _) (b, _) =
    if a > b then -1 else if a = b then 0 else 1
  in
  blks
  |> List.sort by_decreasing_minpc
  |> List.map snd
  |> List.concat

(* Build the enumeration: compute the join points of [f], cut its reachable
   code into basic blocks, then flatten them.  [reach] is not consulted. *)

let enumerate_aux_flat f reach =
  let joins = join_points f in
  let blocks = basic_blocks f joins in
  flatten_blocks blocks

(**
 * Enumeration based on traces as identified by Duplicate.v
 *
 * The Duplicate phase heuristically identifies the most frequently executed
 * paths. Each Icond is modified so that the preferred successor is the
 * fallthrough (ifnot) rather than the branch (ifso).
 *
 * The enumeration below takes advantage of this, preferring to lay out nodes
 * along the fallthroughs of the Lcond branches.
 *
 * It is slightly adapted from the work of Pettis and Hansen (1990) on
 * intraprocedural code positioning, except that we work at a coarser grain,
 * since we do not have exact frequencies (we only know which branch is the
 * preferred one).
 *)

(* [get_some opt] unwraps [opt].
   @raise Failure "Did not get some" when [opt] is [None]. *)
let get_some opt =
  match opt with
  | Some thing -> thing
  | None -> failwith "Did not get some"

(* Raised by [last_element] on an empty list. *)
exception EmptyList

(* [last_element l] is the final element of [l].
   @raise EmptyList if [l] is empty. *)
let rec last_element = function
  | [] -> raise EmptyList
  | [x] -> x
  | _ :: (_ :: _ as tail) -> last_element tail

(* [forward_sequences code entry] partitions the nodes reachable from
   [entry] into "forward sequences": chains that follow, from each node, the
   fallthrough successor of its block's last instruction ([Lbranch] target,
   [Lcond] [ifnot], or the first [Ljumptable] entry) while it is unvisited.
   The other successors ([Lcond] [ifso], the remaining jumptable entries)
   become starting points of later sequences.  The sequence starting at
   [entry] comes first; every node appears in at most one sequence and
   empty sequences are omitted.
   @raise Failure if a successor is missing from [code].
   @raise Assert_failure if a block ends with a non-control instruction. *)
let forward_sequences code entry =
  let visited = ref (PTree.map (fun _ _ -> false) code) in
  (* [traverse_fallthrough node] returns the nodes traversed from [node]
     (in order) and the list of nodes to start traversing next.  An
     already-visited [node] yields ([], []). *)
  let rec traverse_fallthrough node =
    if get_some @@ PTree.get node !visited then ([], [])
    else begin
      visited := PTree.set node true !visited;
      match PTree.get node code with
      | None -> failwith "No such node"
      | Some bb ->
          let ln, rem =
            match last_element bb with
            | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
            | Lbuiltin _ -> assert false
            | Ltailcall _ | Lreturn -> ([], [])
            | Lbranch n -> traverse_fallthrough n
            | Lcond (_, _, ifso, ifnot) ->
                let ln, rem = traverse_fallthrough ifnot in
                (ln, ifso :: rem)
            | Ljumptable (_, tbl) ->
                (match tbl with
                 | [] -> ([], [])
                 | n :: others ->
                     let ln, rem = traverse_fallthrough n in
                     (ln, others @ rem))
          in
          (node :: ln, rem)
    end
  in
  (* Worklist of starting points.  Newly discovered starting points are
     handled before the pending ones (depth-first), but the pending ones
     must be kept: nodes reachable only through them would otherwise be
     missing from the enumeration. *)
  let rec collect acc = function
    | [] -> List.rev acc
    | node :: pending ->
        let seq, starts = traverse_fallthrough node in
        let acc = match seq with [] -> acc | _ -> seq :: acc in
        collect acc (starts @ pending)
  in
  collect [] [entry]

(* Ordering on positives (node labels) by their integer value. *)
module PInt = struct
  type t = P.t
  let compare a b =
    let ia = P.to_int a in
    let ib = P.to_int b in
    compare ia ib
end

(* Sets of node labels. *)
module PSet = Set.Make (PInt)

(* Total order on node sequences: element-wise with [PInt.compare]; when one
   sequence is a proper prefix of the other, the SHORTER one is the greater
   (in particular [] is greater than any non-empty sequence). *)
module LPInt = struct
  type t = P.t list
  let rec compare xs ys =
    match xs, ys with
    | [], [] -> 0
    | [], _ :: _ -> 1
    | _ :: _, [] -> -1
    | x :: xs', y :: ys' ->
        let c = PInt.compare x y in
        if c <> 0 then c else compare xs' ys'
end

(* Sets of node sequences. *)
module LPSet = Set.Make (LPInt)

(* Apply [f] to every element of the sequence set [s], in increasing order. *)
let iter_lpset f s = LPSet.iter f s

(* [first_of l] is [Some] of the head of [l], or [None] if [l] is empty. *)
let first_of l =
  match l with
  | x :: _ -> Some x
  | [] -> None

(* [last_of l] is [Some] of the last element of [l], or [None] if [l] is
   empty.  Unlike [last_element], it never raises.  (The previous version
   recursed on the tail of the tail, skipping elements and returning [None]
   for every even-length list.) *)
let rec last_of = function
  | [] -> None
  | [x] -> Some x
  | _ :: rest -> last_of rest

(* [can_be_merged code s s'] holds when the last node of sequence [s] falls
   through to the first node of [s'] (its [Lbranch] target, [Lcond] [ifnot],
   or first [Ljumptable] entry), so [s'] could follow [s] without an extra
   branch.  Here [code] maps a node directly to an instruction.
   NOTE(review): LTL [fn_code] maps nodes to whole blocks; a caller holding
   such a map would first need to project each block to its last
   instruction -- confirm the intended input.
   @raise Failure if [s] or [s'] is empty or the last node of [s] is not
   bound in [code]. *)
let can_be_merged code s s' =
  let last_s = get_some @@ last_of s in
  let first_s' = get_some @@ first_of s' in
  (* Node labels are compared structurally: physical equality ([==]) would
     miss labels that are equal but were built separately. *)
  match get_some @@ PTree.get last_s code with
  | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _ | Lcall _
  | Lbuiltin _ | Ltailcall _ | Lreturn -> false
  | Lbranch n -> n = first_s'
  | Lcond (_, _, _, ifnot) -> ifnot = first_s'
  | Ljumptable (_, ln) ->
      (match ln with
       | [] -> false
       | n :: _ -> n = first_s')

(* [merge s s'] is the sequence [s] followed by the sequence [s'] it falls
   through to, wrapped in [Some].  The previous stub returned [Some s],
   which made [try_merge] silently drop every node of [s']. *)
let merge s s' = Some (s @ s')

(* [try_merge code fs] repeatedly replaces a pair of distinct sequences
   [s], [s'] satisfying [can_be_merged code s s'] by [merge s s'], until no
   such pair remains, and returns the resulting set of sequences.  Pairs are
   searched with [s] in increasing [LPSet] order, then [s'] likewise. *)
let try_merge code (fs: (BinNums.positive list) list) =
  (* First mergeable pair of distinct set elements, in scan order. *)
  let find_mergeable seqs =
    let elts = LPSet.elements seqs in
    let rec scan_outer = function
      | [] -> None
      | s :: rest ->
          let rec scan_inner = function
            | [] -> scan_outer rest
            | s' :: more ->
                (* [elts] holds each set element once, so physical identity
                   singles out exactly the pairing of [s] with itself. *)
                if s != s' && can_be_merged code s s' then Some (s, s')
                else scan_inner more
          in
          scan_inner elts
    in
    scan_outer elts
  in
  let rec loop seqs =
    match find_mergeable seqs with
    | None -> seqs
    | Some (s, s') ->
        let seqs = LPSet.remove s' (LPSet.remove s seqs) in
        loop (LPSet.add (get_some (merge s s')) seqs)
  in
  loop (LPSet.of_list fs)

(** Code adapted from Duplicateaux.get_loop_headers
  *
  * Getting loop branches with a DFS visit :
  * Each node is either Unvisited, Visited, or Processed
  * pre-order: node becomes Processed
  * post-order: node becomes Visited
  *
  * If we come across an edge to a Processed node, it's a loop!
  *)
type pos = BinNums.positive

(* Pairs of node labels, used as (from, to) edge keys; ordered
   lexicographically with the polymorphic [compare]. *)
module PP = struct
  type t = pos * pos
  let compare (ax, ay) (bx, by) =
    match compare ax bx with
    | 0 -> compare ay by
    | dx -> dx
end

module PPMap = Map.Make (PP)

(* DFS node states used by [get_loop_edges]. *)
type vstate = Unvisited | Processed | Visited

(* [get_loop_edges code entry] runs a depth-first search over the blocks of
   [code] from [entry], taking successors from each block's last
   instruction, and returns a map binding to [true] every edge
   (from, to) whose target is still on the DFS stack (state [Processed]),
   i.e. every back edge.
   @raise Failure if a successor is missing from [code].
   @raise Assert_failure if a block ends with a non-control instruction. *)
let get_loop_edges code entry =
  let visited = ref (PTree.map (fun _ _ -> Unvisited) code) in
  let is_loop_edge = ref PPMap.empty in
  let rec dfs_visit code from = function
    | [] -> ()
    | node :: ln ->
        (match get_some @@ PTree.get node !visited with
         | Visited -> ()
         | Processed ->
             (* Back edge.  [node] must stay [Processed] until its own visit
                completes, so later back edges to it are detected too. *)
             let from_node = get_some from in
             is_loop_edge := PPMap.add (from_node, node) true !is_loop_edge
         | Unvisited ->
             visited := PTree.set node Processed !visited;
             let bb = get_some @@ PTree.get node code in
             let next_visits =
               match last_element bb with
               | Lop _ | Lload _ | Lgetstack _ | Lsetstack _ | Lstore _
               | Lcall _ | Lbuiltin _ -> assert false
               | Ltailcall _ | Lreturn -> []
               | Lbranch n -> [n]
               | Lcond (_, _, ifso, ifnot) -> [ifso; ifnot]
               | Ljumptable (_, tbl) -> tbl
             in
             dfs_visit code (Some node) next_visits;
             visited := PTree.set node Visited !visited);
        (* The remaining successors of [from] are explored whatever the
           state of [node]. *)
        dfs_visit code from ln
  in
  dfs_visit code None [entry];
  !is_loop_edge

(* [ppmap_is_true pp ppmap] is the boolean bound to [pp] in [ppmap], or
   [false] when [pp] is unbound. *)
let ppmap_is_true pp ppmap =
  match PPMap.find_opt pp ppmap with
  | Some b -> b
  | None -> false

(* Integer ordering for [ISet].  NOTE: this module shadows any outer [Int]
   module for the rest of the file. *)
module Int = struct
  type t = int
  let compare = compare
end

(* Sets of sequence indices. *)
module ISet = Set.Make (Int)

(* Debug helper: print [s] to stdout as "{e1, e2, ..., }", elements in
   increasing order, each followed by ", ". *)
let print_iset s =
  print_string "{";
  ISet.iter (fun e -> print_int e; print_string ", ") s;
  print_string "}"

(* Debug helper: print the array of index sets [dm] to stdout as
   "[|{...}, {...}, |]" followed by a newline. *)
let print_depmap dm =
  print_string "[|";
  Array.iter (fun deps -> print_iset deps; print_string ", ") dm;
  print_string "|]\n"

(* [construct_depmap code entry fs] returns an array [depmap] with one set
   per sequence [fs.(i)].  [depmap.(i)] contains [j] when some node of
   sequence [j] ends with an instruction that branches (the [ifso] of an
   [Lcond], or any [Ljumptable] entry) to a node of sequence [i <> j]
   through an edge that is not a loop edge (see [get_loop_edges]).  Only
   nodes reachable from [entry] are examined; a node found in no sequence
   is attributed to sequence 0.
   @raise Failure if a successor is missing from [code] or a block does not
   end with a control instruction. *)
let construct_depmap code entry fs =
  let is_loop_edge = get_loop_edges code entry in
  let visited = ref (PTree.map (fun _ _ -> false) code) in
  let depmap = Array.map (fun _ -> ISet.empty) fs in
  (* Index of the sequence holding each node, keyed by the node's integer
     value: labels are compared by value (physical equality would miss equal
     labels built separately), and each lookup is O(1) instead of a scan of
     all sequences.  If a node occurred in several sequences, the last one
     would win, as with a left-to-right scan. *)
  let index_of_node = Hashtbl.create 17 in
  Array.iteri
    (fun i s ->
      List.iter (fun n -> Hashtbl.replace index_of_node (P.to_int n) i) s)
    fs;
  let find_index_of_node n =
    match Hashtbl.find_opt index_of_node (P.to_int n) with
    | Some i -> i
    | None -> 0
  in
  (* Record that the sequence of [succ] depends on the sequence of [node],
     unless the edge is a loop edge or both nodes lie in the same sequence
     (a self-dependency could never be satisfied by [order_sequences]). *)
  let add_dep node succ =
    if not (ppmap_is_true (node, succ) is_loop_edge) then begin
      let in_index = find_index_of_node node in
      let out_index = find_index_of_node succ in
      if in_index <> out_index then
        depmap.(out_index) <- ISet.add in_index depmap.(out_index)
    end
  in
  let rec dfs_visit = function
    | [] -> ()
    | node :: ln ->
        if not (get_some @@ PTree.get node !visited) then begin
          visited := PTree.set node true !visited;
          let bb = get_some @@ PTree.get node code in
          let next_visits =
            match last_element bb with
            | Ltailcall _ | Lreturn -> []
            | Lbranch n -> [n]
            | Lcond (_, _, ifso, ifnot) -> add_dep node ifso; [ifso; ifnot]
            | Ljumptable (_, tbl) -> List.iter (add_dep node) tbl; tbl
            (* end of bblocks should not be another value than one of the above *)
            | _ -> failwith "last_element gave an invalid output"
          in
          dfs_visit next_visits
        end;
        (* The remaining successors of the parent must be visited too. *)
        dfs_visit ln
  in
  dfs_visit [entry];
  depmap

(* Debug helper: print the node labels of [s] to stdout as
   "[n1, n2, ..., ]" followed by a newline. *)
let print_sequence s =
  print_string "[";
  List.iter (fun node -> print_int (P.to_int node); print_string ", ") s;
  print_string "]\n"

(* Debug helper: print a list of sequences, one [print_sequence] per
   element, between "[" and "]\n". *)
let print_ssequence ofs =
  print_string "[";
  List.iter print_sequence ofs;
  print_string "]\n"

(* [order_sequences code entry fs] returns the sequences of [fs] reordered
   so that, whenever possible, each sequence comes after the sequences it
   depends on (see [construct_depmap]).  At each step the lowest-indexed
   unplaced sequence with no pending dependency is chosen; if every
   unplaced sequence still has pending dependencies (a dependency cycle),
   the lowest-indexed unplaced one is chosen instead, so the result is
   always a permutation of [fs].  Produces no output. *)
let order_sequences code entry fs =
  let fs_a = Array.of_list fs in
  let nb_seqs = Array.length fs_a in
  let depmap = construct_depmap code entry fs_a in
  let fs_evaluated = Array.make nb_seqs false in
  let ordered_fs = ref [] in
  let nb_evaluated = ref 0 in
  let evaluate s_id =
    assert (not fs_evaluated.(s_id));
    ordered_fs := fs_a.(s_id) :: !ordered_fs;
    fs_evaluated.(s_id) <- true;
    incr nb_evaluated;
    (* [s_id] is now placed, so it no longer blocks any other sequence. *)
    Array.iteri (fun i deps -> depmap.(i) <- ISet.remove s_id deps) depmap
  in
  (* Lowest index [i] of an unplaced sequence satisfying [p i], or -1. *)
  let first_unevaluated p =
    let rec search i =
      if i >= nb_seqs then -1
      else if (not fs_evaluated.(i)) && p i then i
      else search (i + 1)
    in
    search 0
  in
  let select_next () =
    let free = first_unevaluated (fun i -> ISet.is_empty depmap.(i)) in
    if free >= 0 then free
    else first_unevaluated (fun _ -> true)
  in
  while !nb_evaluated < nb_seqs do
    evaluate (select_next ())
  done;
  List.rev !ordered_fs

(* Trace-based enumeration: split the code of [f] into forward sequences
   from its entry point, order them with [order_sequences], and concatenate
   the result.  [reach] is not consulted. *)
let enumerate_aux_trace f reach =
  let code = f.fn_code and entry = f.fn_entrypoint in
  forward_sequences code entry
  |> order_sequences code entry
  |> List.flatten

(* Choose the enumeration heuristic: trace-based when
   [Clflags.option_ftracelinearize] is set, basic-block based otherwise. *)
let enumerate_aux f reach =
  match !Clflags.option_ftracelinearize with
  | true -> enumerate_aux_trace f reach
  | false -> enumerate_aux_flat f reach